diff --git a/t5x-main/.github/workflows/build.yaml b/t5x-main/.github/workflows/build.yaml deleted file mode 100644 index 2f94a6a4803370721953d967f438051a02bff461..0000000000000000000000000000000000000000 --- a/t5x-main/.github/workflows/build.yaml +++ /dev/null @@ -1,39 +0,0 @@ -name: build - -on: [push] - -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.10.x' - cache: 'pip' - cache-dependency-path: setup.py - - name: Install dependencies - run: | - pip install -e .[test] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - - name: Test with pytest - run: | - pytest - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' diff --git a/t5x-main/CONTRIBUTING.md b/t5x-main/CONTRIBUTING.md deleted file mode 100644 index 8cc085b5d1361202bc8456a5cefd565075bd59ee..0000000000000000000000000000000000000000 --- a/t5x-main/CONTRIBUTING.md +++ /dev/null @@ -1 +0,0 @@ -External contributions are not accepted, sorry! diff --git a/t5x-main/LICENSE b/t5x-main/LICENSE deleted file mode 100644 index d645695673349e3947e8e5ae42332d0ac3164cd7..0000000000000000000000000000000000000000 --- a/t5x-main/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/t5x-main/README.md b/t5x-main/README.md deleted file mode 100644 index 916b684b8364a16d4ca35d21db0a7487b9bcb413..0000000000000000000000000000000000000000 --- a/t5x-main/README.md +++ /dev/null @@ -1,525 +0,0 @@ -# T5X - -*Go to [T5X ReadTheDocs Documentation Page](https://t5x.readthedocs.io/).* - -T5X is a modular, composable, research-friendly framework for high-performance, -configurable, self-service training, evaluation, and inference of sequence -models (starting with language) at many scales. - -It is essentially a new and improved implementation of the -[T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer) -(based on [Mesh TensorFlow](https://github.com/tensorflow/mesh)) in [JAX](https://github.com/google/jax) and [Flax](https://github.com/google/flax). To learn -more, see the [T5X Paper](https://arxiv.org/abs/2203.17189). - -Below is a quick start guide for training models with TPUs on Google Cloud. For -additional tutorials and background, see the [complete documentation](docs/index.md). - -## Quickstart (Recommended) - -T5X can be run with [XManager](https://github.com/deepmind/xmanager) on -[Vertex AI](https://cloud.google.com/vertex-ai). Vertex AI is a platform for -training that creates TPU instances and runs code on the TPUs. Vertex AI will -also shut down the TPUs when the jobs terminate. This is signifcantly easier -than managing GCE VMs and TPU VM instances. - -1. Follow the pre-requisites and directions to install [XManager](https://github.com/deepmind/xmanager). - -2. Request TPU quota as required. GCP projects come with 8 cores by default, -which is enough to run one training experiment on a single TPU host. If you want -to run multi-host training or run multiple trials in parallel, you will need -more quota. Navigate to [Quotas](https://console.cloud.google.com/quotas). - - The quota you want is: - - * Service: `Vertex AI API` - * Dimensions (location): `us-central1` - * If you want to run single-host experiments: - * `Custom model training TPU V2 cores per region` - * `Custom model training TPU V3 cores per region` - * If you want to run multi-host experiments: - * `Custom model training TPU V2 pod cores per region` - * `Custom model training TPU V3 pod cores per region` - - TIP: You won't be able to run single-host experiments with multi-host quota. - (i.e. you can't run `tpu_v2=8` using `TPU V2 pod`) - - -3. Launch the xmanager script located at `t5x/scripts/xm_launch.py`. - -As a running example, we use the WMT14 En-De translation which is described in -more detail in the Examples section below. - -```sh -export GOOGLE_CLOUD_BUCKET_NAME=... -export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data -export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d) - -# Pre-download dataset in multi-host experiments. -tfds build wmt_t2t_translate --data_dir=$TFDS_DATA_DIR - -git clone https://github.com/google-research/t5x -cd ./t5x/ - -python3 ./t5x/scripts/xm_launch.py \ - --gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin \ - --model_dir=$MODEL_DIR \ - --tfds_data_dir=$TFDS_DATA_DIR -``` - -Check `gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/` for the output artifacts, which can -be read by TensorBoard. - -## GPU Usage -Note: NVIDIA has released an updated version of this repository with H100 FP8 support and broad GPU performance improvements. Please visit the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository for more details and usage instructions. - -T5X can be run easily on GPUs either in single-node configurations or multi-node configurations with a SLURM+pyxis cluster. Further instructions at [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md). The `t5x/contrib/gpu/scripts_gpu` folder contains example scripts for pretraining T5X on [The Pile](https://pile.eleuther.ai/) and for finetuning on SQuAD and MNLI. These scripts and associated `gin` configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository maintained by NVIDIA with H100 FP8 support and broad GPU performance improvements. - - -## Installation - -Note that all the commands in this document should be run in the commandline of -the TPU VM instance unless otherwise stated. - -1. Follow the - [instructions](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_the_google_cloud_sdk) - to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU - API. - - **Note:** T5X also works with GPU, please follow instructions in [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md) if you'd like to use GPU version. - -2. Create a - [Cloud TPU VM instance](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms) - following - [this instruction](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#create-vm). - We recommend that you develop your workflow in a single v3-8 TPU (i.e., - `--accelerator-type=v3-8`) and scale up to pod slices once the pipeline is - ready. In this README, we focus on using a single v3-8 TPU. See - [here](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) to - learn more about TPU architectures. - -3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM. - You can install packages, run your code run, etc. in the host machine. Once - the TPU instance is created, ssh into it with - - ```sh - gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} - ``` - - where `TPU_NAME` and `ZONE` are the name and the zone used in step 2. - -4. Install T5X and the dependencies. - - ```sh - git clone --branch=main https://github.com/google-research/t5x - cd t5x - - python3 -m pip install -e '.[tpu]' -f \ - https://storage.googleapis.com/jax-releases/libtpu_releases.html - - ``` - - -5. Create Google Cloud Storage (GCS) bucket to store the dataset and model - checkpoints. To create a GCS bucket, see these - [instructions](https://cloud.google.com/storage/docs/creating-buckets). - -6. (optional) If you prefer working with Jupyter/Colab style environment - you can setup a custom Colab runtime by following steps from - [t5x/notebooks](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md). - -## Example: English to German translation - -As a running example, we use the WMT14 En-De translation. The raw dataset is -available in TensorFlow Datasets as -["wmt_t2t_translate"](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate). - -T5 casts the translation task such as the following - -```py -{'en': 'That is good.', 'de': 'Das ist gut.'} -``` - -to the form called "text-to-text": - -```py -{'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'} -``` - -This formulation allows many different classes of language tasks to be expressed -in a uniform manner and a single encoder-decoder architecture can handle them -without any task-specific parameters. For more detail, refer to the [T5 paper -(Raffel et al. 2019)][t5_paper]. - -For a scalable data pipeline and an evaluation framework, we use -[`SeqIO`](https://github.com/google/seqio), which was factored out of the [T5 -library][t5_github]. A `seqio.Task` packages together the raw dataset, vocabulary, -preprocessing such as tokenization and evaluation metrics such as -[BLEU](https://aclanthology.org/P02-1040.pdf) and provides a -[`tf.data`](https://www.tensorflow.org/guide/data) instance. - -[The T5 library][t5_github] provides a number of `seqio.Task`s that were used in the -[T5 paper][t5_paper]. In this example, we use [wmt_t2t_ende_v003](https://github.com/google-research/text-to-text-transfer-transformer/blob/d81c0bab2a41b4d5dfbe4971de32f7d67df65f31/t5/data/tasks.py#L212). - -Before training or fine-tuning you need to download ["wmt_t2t_translate"] -(https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate) dataset first. - -```sh -# Data dir to save the processed dataset in "gs://data_dir" format. -TFDS_DATA_DIR="..." - -# Make sure that dataset package is up-to-date. -python3 -m pip install --upgrade tfds-nightly - -# Pre-download dataset. -tfds build wmt_t2t_translate ${TFDS_DATA_DIR} -``` - -### Training - -To run a training job, we use the `t5x/train.py` script. - -```sh -# Model dir to save logs, ckpts, etc. in "gs://model_dir" format. -MODEL_DIR="..." -T5X_DIR="..." # directory where the T5X repo is cloned. -TFDS_DATA_DIR="..." - -python3 ${T5X_DIR}/t5x/train.py \ - --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \ - --gin.MODEL_DIR=\"${MODEL_DIR}\" \ - --tfds_data_dir=${TFDS_DATA_DIR} -``` - -The configuration for this training run is defined in the Gin file -[base_wmt_from_scratch.gin](t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin). -[Gin-config](https://github.com/google/gin-config) is a library to handle -configurations based on dependency injection. Among many benefits, Gin allows -users to pass custom components such as a custom model to the T5X library -without having to modify the core library. The [custom -components](#custom-components) section shows how this is done. - -While the core library is independent of Gin, it is central to the examples we -provide. Therefore, we provide a short [introduction][gin-primer] to Gin in the -context of T5X. All the configurations are written to a file "config.gin" in -`MODEL_DIR`. This makes debugging as well as reproducing the experiment much -easier. - -In addition to the `config.json`, `model-info.txt` file summarizes the model -parameters (shape, names of the axes, partitioning info) as well as the -optimizer states. - - - -#### TensorBoard - -To monitor the training in [TensorBoard](https://www.tensorflow.org/tensorboard), it is much easier (due to -authentification issues) to launch the TensorBoard on your own machine and _not_ in -the TPU VM. So in the commandline where you ssh'ed into the TPU VM, launch the -TensorBoard with the `logdir` pointing to the `MODEL_DIR`. - -```sh -# NB: run this on your machine not TPU VM! -MODEL_DIR="..." # Copy from the TPU VM. -tensorboard --logdir=${MODEL_DIR} -``` - -Or you can launch the TensorBoard inside a Colab. In a Colab cell, run - -```python -from google.colab import auth -auth.authenticate_user() -``` - -to authorize the Colab to access the GCS bucket and launch the TensorBoard. - -```python -%load_ext tensorboard -model_dir = "..." # Copy from the TPU VM. -%tensorboard --logdir=model_dir -``` - - -### Fine-tuning - -We can leverage the benefits of self-supervised pre-training by initializing -from one of our pre-trained models. Here we use the T5.1.1 Base checkpoint. - -```sh -# Model dir to save logs, ckpts, etc. in "gs://model_dir" format. -MODEL_DIR="..." - -# Data dir to save the processed dataset in "gs://data_dir" format. -TFDS_DATA_DIR="..." -T5X_DIR="..." # directory where the T5X repo is cloned. - -python3 ${T5X_DIR}/t5x/train.py \ - --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin" \ - --gin.MODEL_DIR=\"${MODEL_DIR}\" \ - --tfds_data_dir=${TFDS_DATA_DIR} -``` - -**Note:** when supplying a string, dict, list, tuple value, or a bash variable -via a flag, you must put it in quotes. In the case of strings, it requires -escaped quotes (`\"\"`). For example: -`--gin.utils.DatasetConfig.split=\"validation\"` or -`--gin.MODEL_DIR=\"${MODEL_DIR}\"`. - -Gin makes it easy to change a number of configurations. For example, you can -change the `partitioning.PjitPartitioner.num_partitions` (overriding -the value in -[base_wmt_from_scratch.gin](t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin)) -to chanage the parallelism strategy and pass it as a commandline arg. - -```sh ---gin.partitioning.PjitPartitioner.num_partitions=8 -``` - -### Evaluation - -To run the offline (i.e. without training) evaluation, you can use `t5x/eval.py` -script. - -```sh -EVAL_OUTPUT_DIR="..." # directory to write eval output -T5X_DIR="..." # directory where the t5x is cloned, e.g., ${HOME}"/t5x". -TFDS_DATA_DIR="..." -CHECKPOINT_PATH="..." - -python3 ${T5X_DIR}/t5x/eval.py \ - --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin" \ - --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \ - --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \ - --tfds_data_dir=${TFDS_DATA_DIR} -``` - - -### Inference - -To run inference, you can use `t5x/infer.py` script. Here we use the same -`seqio.Task`, but for inference we do not use the targets features other than -logging them alongside the prediction in a JSON file. - -```sh -INFER_OUTPUT_DIR="..." # directory to write infer output -T5X_DIR="..." # directory where the t5x is cloned, e.g., ${HOME}"/t5x". -TFDS_DATA_DIR="..." -CHECKPOINT_PATH="..." - -python3 ${T5X_DIR}/t5x/infer.py \ - --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin" \ - --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \ - --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \ - --tfds_data_dir=${TFDS_DATA_DIR} -``` - -### Exporting as TensorFlow Saved Model - -Pretrained model can be exported as TensorFlow Saved Model, and deployed -to Vertex AI Prediction service using [Optimized TensorFlow Runtime] -(https://cloud.google.com/vertex-ai/docs/predictions/optimized-tensorflow-runtime). -Please note that exported model won't work on OSS based -[TensorFlow Model Server](https://github.com/tensorflow/serving). - -```sh -T5X_DIR="..." # directory where the t5x is cloned, e.g., ${HOME}"/t5x". -CHECKPOINT_PATH="..." - -BATCH_SIZE=None -BEAM_SIZE=1 - -# Use 'bfloat16' if you plan to run exported model on NVIDIA A100 or newer GPUs, -# for other GPUs use 'float32'. -ACTIVATION_DTYPE=bfloat16 - -# Version numbers must be numeric. We generate one based on datetime. -VERSION=$(date +%Y%m%d%H%M%S) - -NAME=t5x_base_${ACTIVATION_DTYPE} # Model name. - -# Path to export model to. Note that export script is going to add _cpu suffix -# after model name. -OUTPUT=${CHECKPOINT_PATH}/saved_model.${NAME}/${VERSION} - -declare -a ARGS=( ---gin_file=t5x/examples/t5/t5_1_1/base.gin ---gin_file=t5x/t5x/configs/runs/export.gin ---gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}" ---gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" ---gin.MODEL_NAME=\"/ml/${USER}/t5x_base\" ---gin.MODEL_OUTPUT_DIR=\"${OUTPUT}\" ---gin.BEAM_SIZE=${BEAM_SIZE} ---gin.BATCH_SIZE=${BATCH_SIZE} ---gin.export_lib.save.partitioner=None ---gin.export_lib.save.warmup_examples="['hello world']" ---gin.export_lib.ExportableModule.use_batch_function=False ---gin.export_lib.ExportableModule.use_gpu=False ---gin.export_lib.ExportableModule.jit_compile=False ---gin.ACTIVATION_DTYPE=\"${ACTIVATION_DTYPE}\" ---gin.network.T5Config.dtype=\"${ACTIVATION_DTYPE}\" ---gin.utils.RestoreCheckpointConfig.dtype=\"${ACTIVATION_DTYPE}\" ---gin.DROPOUT_RATE=0.0 -) - -(python3 ${T5X_DIR}/t5x/export.py "${ARGS[@]}") -``` - -For detailed arguments definition refer to [export.gin] -(t5x/configs/runs/export.gin). - -You can run XL and smaller models on NVIDIA A100 40GB, and XXL models on -NVIDIA A100 80GB. - -## Custom components - -[The translation example](#example-english-to-german-translation) uses the -encoder-decoder model that T5X provides as well as the dataset from the T5 -library. This section shows how you can use your own dataset and a model and -pass via Gin. - -### Example: custom dataset in a user directory - -For this example, we have the following directory structure with -`${HOME}/dir1/user_dir` representing a user directory with custom components. - -``` -${HOME} -└── dir1 -    └── user_dir -    ├── t5_1_1_base_de_en.gin -    └── tasks.py -``` - -As an example, let's define a new dataset. Here we use the same Translation -dataset but we define the translation task in the opposite direction, i.e., -German to English intead of English to German. We define this task in `tasks.py` - -```py -# ${HOME}/dir1/user_dir/tasks.py - -import functools -import seqio -import tensorflow_datasets as tfds -from t5.evaluation import metrics -from t5.data import preprocessors - -vocabulary = seqio.SentencePieceVocabulary( - 'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model', extra_ids=100) -output_features = { - 'inputs': seqio.Feature(vocabulary=vocabulary), - 'targets': seqio.Feature(vocabulary=vocabulary) -} - -seqio.TaskRegistry.add( - 'wmt_t2t_de_en_v003', - source=seqio.TfdsDataSource(tfds_name='wmt_t2t_translate/de-en:1.0.0'), - preprocessors=[ - functools.partial( - preprocessors.translate, - source_language='de', target_language='en'), - seqio.preprocessors.tokenize, - seqio.CacheDatasetPlaceholder(), - seqio.preprocessors.append_eos_after_trim, - ], - metric_fns=[metrics.bleu], - output_features=output_features) -``` - -In the Gin file, most of the settings are equivalent to those used in the -[En->De example](#example-english-to-german-translation). So we include the Gin -file from that example. To use "wmt_t2t_de_en_v003" task we just defined, we -need to import the task module "tasks.py". Note that we use a relative path -defined with respect to the user directory. This will be specified as a -flag. - -```py -# ${HOME}/dir1/user_dir/t5_1_1_base_de_en.gin -from __gin__ import dynamic_registration -import tasks # This imports the task defined in dir1/user_dir/tasks.py. - -include "t5x-tmp/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" -MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003" -``` - -Finally, we launch training passing the user directory as a flag -`gin_search_paths` such that the Gin file and python modules can be specified -with relative paths. - -```sh -PROJECT_DIR=${HOME}"/dir1/user_dir" -T5X_DIR="..." # directory where the t5x is cloned. -TFDS_DATA_DIR="..." -MODEL_DIR="..." -export PYTHONPATH=${PROJECT_DIR} - -python3 ${T5X_DIR}/t5x/train.py \ - --gin_search_paths=${PROJECT_DIR} \ - --gin_file="t5_1_1_base_de_en.gin" \ - --gin.MODEL_DIR=\"${MODEL_DIR}\" \ - --tfds_data_dir=${TFDS_DATA_DIR} -``` - -## Checkpoints - -### Native Checkpoints - -We have released the checkpoints of many of the original T5 models and their -variants a native T5X format for maximal efficiency. -See the [complete list](https://github.com/google-research/t5x/blob/main/docs/models.md) including the -matching Gin configuration files. - -These are converted from the public [Mesh TensorFlow -checkpoints](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) -. - - -### Compatibility with the Mesh TensorFlow checkpoints -The Mesh TensorFlow checkpoints trained using the [T5 library][t5_github] can be -directly loaded into T5X. For example, we can rerun the fine-tuning example -initializing from the MTF checkpoint by changing the `INIT_CHECKPOINT` Gin -macro. - -```sh -# Model dir to save logs, ckpts, etc. in "gs://model_dir" format. -MODEL_DIR="..." - -# Data dir to save the processed dataset in "gs://data_dir" format. -TFDS_DATA_DIR="..." -T5X_DIR="..." # directory where the T5X repo is cloned. - -python3 ${T5X_DIR}/t5x/train.py \ - --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin" \ - --gin.MODEL_DIR=\"${MODEL_DIR}\" \ - --gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \ - --gin.INIT_CHECKPOINT=\"gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000\" \ - --tfds_data_dir=${TFDS_DATA_DIR} -``` - -Note that restoring directly from the Mesh TensorFlow checkpoints can be -inefficient if heavy model parallelism is used for large models. This is -because each host loads the entire copy of the model first and then keep only -the relevant slices dictated by the model parallelism specification. If you have -Mesh TensorFlow checkpoints that you run often, we recommend converting the -checkpoints to T5X native format using the -[convert_tf_checkpoint script](t5x/scripts/convert_tf_checkpoint.py). - - -## Citing T5X -Please use the following bibtex entry to cite T5X. - -``` -@article{roberts2022t5x, - url = {https://arxiv.org/abs/2203.17189}, - author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra, Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester, Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy, Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel, Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov, Alexander and Newlan, Joshua and Gesmundo, Andrea}, - title = {Scaling Up Models and Data with $\texttt{t5x}$ and $\texttt{seqio}$}, - journal={arXiv preprint arXiv:2203.17189}, - year = {2022}, -} -``` - - -## Note -This is not an officially supported Google product - -[t5_paper]: https://arxiv.org/abs/1910.10683 -[t5_github]: https://github.com/google-research/text-to-text-transfer-transformer -[gin-primer]: docs/usage/gin.md diff --git a/t5x-main/docs/_static/t5x_theme.css b/t5x-main/docs/_static/t5x_theme.css deleted file mode 100644 index 9a820ccda8c8bdd9395fe40947e747a6b68fb3b7..0000000000000000000000000000000000000000 --- a/t5x-main/docs/_static/t5x_theme.css +++ /dev/null @@ -1,23 +0,0 @@ -@import url("theme.css"); - -.wy-nav-content { - max-width: 1290px; -} - -.rst-content table.docutils { - width: 100%; -} - -.rst-content table.docutils td { - vertical-align: top; - padding: 0; -} - -.rst-content table.docutils td p { - padding: 8px; -} - -.rst-content div[class^=highlight] { - border: 0; - margin: 0; -} \ No newline at end of file diff --git a/t5x-main/docs/_templates/autosummary/t5x_module.rst b/t5x-main/docs/_templates/autosummary/t5x_module.rst deleted file mode 100644 index 5f51933cff67130130bc49a0345b28c727188e97..0000000000000000000000000000000000000000 --- a/t5x-main/docs/_templates/autosummary/t5x_module.rst +++ /dev/null @@ -1,23 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - :exclude-members: - - {% block methods %} - - .. automethod:: __call__ - - {% if methods %} - .. rubric:: Methods - - .. autosummary:: - - {% for item in methods %} - {%- if item not in inherited_members and item not in annotations and not item in ['__init__'] %} - ~{{ name }}.{{ item }} - {%- endif %} - {%- endfor %} - {% endif %} - {% endblock %} \ No newline at end of file diff --git a/t5x-main/docs/api_reference/index.rst b/t5x-main/docs/api_reference/index.rst deleted file mode 100644 index 6af8449d0c0746c0533550c7c2a8d63076026ab0..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/index.rst +++ /dev/null @@ -1,100 +0,0 @@ -API Reference -============= - -Binaries --------- - -.. toctree:: - :maxdepth: 3 - - t5x.train - t5x.infer - t5x.eval - t5x.main - -Training ---------- - -.. toctree:: - :maxdepth: 3 - - t5x.trainer - t5x.optimizers - t5x.interactive_model - t5x.train_state - t5x.state_utils - t5x.losses - t5x.metrics - t5x.utils - t5x.adafactor - -Inference ---------- - -.. toctree:: - :maxdepth: 3 - - t5x.decoding - -Models ------- - -.. toctree:: - :maxdepth: 3 - - t5x.models - -Checkpointing -------------- - -.. toctree:: - :maxdepth: 3 - - t5x.checkpoints - t5x.checkpoint_utils - t5x.checkpoint_importer - - -Paritioning ------------ - -.. toctree:: - :maxdepth: 3 - - t5x.partitioning - -Config ------- - -.. toctree:: - :maxdepth: 3 - - t5x.config_utils - t5x.gin_utils - -Utils ------ - -.. toctree:: - :maxdepth: 3 - - t5x.test_utils - t5x.binary_search - - - - - - - - - - - - - - - - - - diff --git a/t5x-main/docs/api_reference/t5x.adafactor.rst b/t5x-main/docs/api_reference/t5x.adafactor.rst deleted file mode 100644 index 1d22cf2744be4a7fc4b4d69f79728b4fae09a80a..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.adafactor.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.adafactor package -======================== - -.. currentmodule:: t5x.adafactor - -.. automodule:: t5x.adafactor - :members: diff --git a/t5x-main/docs/api_reference/t5x.binary_search.rst b/t5x-main/docs/api_reference/t5x.binary_search.rst deleted file mode 100644 index c96d9ca444423ed545847f6f688ebb9878f280c7..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.binary_search.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.binary_search package -======================== - -.. currentmodule:: t5x.binary_search - -.. automodule:: t5x.binary_search - :members: diff --git a/t5x-main/docs/api_reference/t5x.checkpoint_importer.rst b/t5x-main/docs/api_reference/t5x.checkpoint_importer.rst deleted file mode 100644 index d7ceaefbcf4773d9ee3e175a1bef927f92b38659..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.checkpoint_importer.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.checkpoint_importer package -======================== - -.. currentmodule:: t5x.checkpoint_importer - -.. automodule:: t5x.checkpoint_importer - :members: diff --git a/t5x-main/docs/api_reference/t5x.checkpoint_utils.rst b/t5x-main/docs/api_reference/t5x.checkpoint_utils.rst deleted file mode 100644 index cb5dcff6a3f76a83dcc0337fa51d5ccf72776a29..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.checkpoint_utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.checkpoint_utils package -======================== - -.. currentmodule:: t5x.checkpoint_utils - -.. automodule:: t5x.checkpoint_utils - :members: diff --git a/t5x-main/docs/api_reference/t5x.checkpoints.rst b/t5x-main/docs/api_reference/t5x.checkpoints.rst deleted file mode 100644 index ed80aef1ed1d914397aa9dbf0103ed7f63819ed9..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.checkpoints.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.checkpoints package -======================== - -.. currentmodule:: t5x.checkpoints - -.. automodule:: t5x.checkpoints - :members: diff --git a/t5x-main/docs/api_reference/t5x.config_utils.rst b/t5x-main/docs/api_reference/t5x.config_utils.rst deleted file mode 100644 index 72994475b27387160e5d63c07f914c4175b7593d..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.config_utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.config_utils package -======================== - -.. currentmodule:: t5x.config_utils - -.. automodule:: t5x.config_utils - :members: diff --git a/t5x-main/docs/api_reference/t5x.decoding.rst b/t5x-main/docs/api_reference/t5x.decoding.rst deleted file mode 100644 index 2f3cfbdde1b7efd526a343b4fb5e153ad12c9444..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.decoding.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.decoding package -======================== - -.. currentmodule:: t5x.decoding - -.. automodule:: t5x.decoding - :members: diff --git a/t5x-main/docs/api_reference/t5x.eval.rst b/t5x-main/docs/api_reference/t5x.eval.rst deleted file mode 100644 index 55b3e22587997ffcf185ee7afdb49dd1f495d73c..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.eval.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.eval binary -======================== - -.. currentmodule:: t5x.eval - -.. automodule:: t5x.eval - :members: diff --git a/t5x-main/docs/api_reference/t5x.gin_utils.rst b/t5x-main/docs/api_reference/t5x.gin_utils.rst deleted file mode 100644 index 37d12b8e6b574c0b76d617cba0cd95d3010f7838..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.gin_utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.gin_utils package -======================== - -.. currentmodule:: t5x.gin_utils - -.. automodule:: t5x.gin_utils - :members: diff --git a/t5x-main/docs/api_reference/t5x.infer.rst b/t5x-main/docs/api_reference/t5x.infer.rst deleted file mode 100644 index a6c877497b485f621447ce7b6936fc70b621839e..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.infer.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.infer binary -======================== - -.. currentmodule:: t5x.infer - -.. automodule:: t5x.infer - :members: diff --git a/t5x-main/docs/api_reference/t5x.interactive_model.rst b/t5x-main/docs/api_reference/t5x.interactive_model.rst deleted file mode 100644 index 66e51a6dd0ebd0ead10bb77150e0a3a1e45bd4a4..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.interactive_model.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.interactive_model package -======================== - -.. currentmodule:: t5x.interactive_model - -.. automodule:: t5x.interactive_model - :members: diff --git a/t5x-main/docs/api_reference/t5x.losses.rst b/t5x-main/docs/api_reference/t5x.losses.rst deleted file mode 100644 index c8f27a4c7230455eb1d18371fca659712238cc2a..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.losses.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.losses package -======================== - -.. currentmodule:: t5x.losses - -.. automodule:: t5x.losses - :members: diff --git a/t5x-main/docs/api_reference/t5x.main.rst b/t5x-main/docs/api_reference/t5x.main.rst deleted file mode 100644 index 9a602609a6d97889b87d505225e3e63cca5ff9c4..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.main.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.main binary -======================== - -.. currentmodule:: t5x.main - -.. automodule:: t5x.main - :members: diff --git a/t5x-main/docs/api_reference/t5x.metrics.rst b/t5x-main/docs/api_reference/t5x.metrics.rst deleted file mode 100644 index 613c6cf3694ea4b0ca2ffe0c8e979b112aee6a6b..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.metrics package -======================== - -.. currentmodule:: t5x.metrics - -.. automodule:: t5x.metrics - :members: diff --git a/t5x-main/docs/api_reference/t5x.models.rst b/t5x-main/docs/api_reference/t5x.models.rst deleted file mode 100644 index e16a7aedfe6d5d49702093974db84eb397fdb6b6..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.models.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.models package -======================== - -.. currentmodule:: t5x.models - -.. automodule:: t5x.models - :members: diff --git a/t5x-main/docs/api_reference/t5x.optimizers.rst b/t5x-main/docs/api_reference/t5x.optimizers.rst deleted file mode 100644 index 133b486da3e1ef2b8346ed64e3e3933d574e8e77..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.optimizers.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.optimizers package -======================== - -.. currentmodule:: t5x.optimizers - -.. automodule:: t5x.optimizers - :members: diff --git a/t5x-main/docs/api_reference/t5x.partitioning.rst b/t5x-main/docs/api_reference/t5x.partitioning.rst deleted file mode 100644 index 1cfeeb91f652355f4ffcf80999412e188168f4c2..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.partitioning.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.partitioning package -======================== - -.. currentmodule:: t5x.partitioning - -.. automodule:: t5x.partitioning - :members: diff --git a/t5x-main/docs/api_reference/t5x.state_utils.rst b/t5x-main/docs/api_reference/t5x.state_utils.rst deleted file mode 100644 index 1b6643c0345fb862a10b211ee2c65b926ad40dcf..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.state_utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.state_utils package -======================== - -.. currentmodule:: t5x.state_utils - -.. automodule:: t5x.state_utils - :members: diff --git a/t5x-main/docs/api_reference/t5x.test_utils.rst b/t5x-main/docs/api_reference/t5x.test_utils.rst deleted file mode 100644 index 13a5331e1f1286d1b8ed72193ec8713264380b85..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.test_utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.test_utils package -======================== - -.. currentmodule:: t5x.test_utils - -.. automodule:: t5x.test_utils - :members: diff --git a/t5x-main/docs/api_reference/t5x.train.rst b/t5x-main/docs/api_reference/t5x.train.rst deleted file mode 100644 index 435a32d06b16a04dd12dca5f47b2033316ed56a9..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.train.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.train binary -======================== - -.. currentmodule:: t5x.train - -.. automodule:: t5x.train - :members: \ No newline at end of file diff --git a/t5x-main/docs/api_reference/t5x.train_state.rst b/t5x-main/docs/api_reference/t5x.train_state.rst deleted file mode 100644 index 8850fa60f2bcdb063bd5029ded97de77c261714b..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.train_state.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.train_state package -======================== - -.. currentmodule:: t5x.train_state - -.. automodule:: t5x.train_state - :members: diff --git a/t5x-main/docs/api_reference/t5x.trainer.rst b/t5x-main/docs/api_reference/t5x.trainer.rst deleted file mode 100644 index e4d164707f655f18cb87c1fdef58b9dff6171266..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.trainer.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.trainer package -======================== - -.. currentmodule:: t5x.trainer - -.. automodule:: t5x.trainer - :members: diff --git a/t5x-main/docs/api_reference/t5x.utils.rst b/t5x-main/docs/api_reference/t5x.utils.rst deleted file mode 100644 index 40ad9b722897b623ecdfe6ec423c503f6f4946e8..0000000000000000000000000000000000000000 --- a/t5x-main/docs/api_reference/t5x.utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -t5x.utils package -======================== - -.. currentmodule:: t5x.utils - -.. automodule:: t5x.utils - :members: diff --git a/t5x-main/docs/conf.py b/t5x-main/docs/conf.py deleted file mode 100644 index 2f5378272d2cbf00e83674814932154d7847d954..0000000000000000000000000000000000000000 --- a/t5x-main/docs/conf.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Configuration file for the Sphinx documentation builder. - -This file only contains a selection of the most common options. For a full -list see the documentation: -https://www.sphinx-doc.org/en/master/usage/configuration.html -""" - -# pylint:disable=all -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys - -sys.path.insert(0, os.path.abspath('..')) - -# patch sphinx -import docs.conf_sphinx_patch - -# -- Project information ----------------------------------------------------- - -project = 'T5X' -copyright = '2023, The T5X authors' # pylint: disable=redefined-builtin -author = 'The T5X authors' - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.autosectionlabel', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'myst_nb', - 'sphinx_design', -] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -source_suffix = ['.rst', '.ipynb', '.md'] - -autosummary_generate = True - -master_doc = 'index' - -autodoc_typehints = 'none' - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -# html_theme = 'pydata_sphinx_theme' -html_theme = 'sphinx_book_theme' -html_css_files = ['css/t5x_theme.css'] - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -html_logo = './t5x.png' -html_favicon = './t5x.png' - -# title of the website -html_title = '' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named 'default.css' will overwrite the builtin 'default.css'. -html_static_path = ['_static'] - -html_theme_options = { - 'repository_url': 'https://github.com/google-research/t5x', - 'use_repository_button': True, # add a 'link to repository' button - 'use_issues_button': False, # add an 'Open an Issue' button - 'path_to_docs': ( - 'docs' - ), # used to compute the path to launch notebooks in colab - 'launch_buttons': { - 'colab_url': 'https://colab.research.google.com/', - }, - 'prev_next_buttons_location': None, - 'show_navbar_depth': 1, -} - -# -- Options for myst ---------------------------------------------- -# uncomment line below to avoid running notebooks during development -# nb_execution_mode = 'off' -# Notebook cell execution timeout; defaults to 30. -nb_execution_timeout = 100 -# List of patterns, relative to source directory, that match notebook -# files that will not be executed. -myst_enable_extensions = ['dollarmath'] -# raise exceptions on execution so CI can catch errors -nb_execution_allow_errors = False -nb_execution_raise_on_error = True - -# -- Extension configuration ------------------------------------------------- - -# Tell sphinx-autodoc-typehints to generate stub parameter annotations including -# types, even if the parameters aren't explicitly documented. -always_document_param_types = True diff --git a/t5x-main/docs/conf_sphinx_patch.py b/t5x-main/docs/conf_sphinx_patch.py deleted file mode 100644 index 3210e3e355b8a3372dcfbf125f68b29487a571bf..0000000000000000000000000000000000000000 --- a/t5x-main/docs/conf_sphinx_patch.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Patch Sphinx to improve documentation aesthetics.""" - -# TODO(cgarciae): Send a PR to sphinx to upstream this fix. -# Issue: https://github.com/google/flax/issues/2196 -# This patch is needed to make autosummary provide the "annotations" -# variable so we can exclude function attributes from the methods list -# in flax_module.rst. The patch as such only adds this single line: -# -# ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())' -# -# We should consider sending a PR to sphinx so we can get rid of this. -# Original source: -# https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351 -from typing import Any, Dict, List, Set, Tuple -import sphinx.ext.autodoc -import sphinx.ext.autosummary.generate as ag - - -# pylint:disable=all -def generate_autosummary_content( - name: str, - obj: Any, - parent: Any, - template: ag.AutosummaryRenderer, - template_name: str, - imported_members: bool, - app: Any, - recursive: bool, - context: Dict, - modname: str = None, - qualname: str = None, -) -> str: - doc = ag.get_documenter(app, obj, parent) - - def skip_member(obj: Any, name: str, objtype: str) -> bool: - try: - return app.emit_firstresult( - 'autodoc-skip-member', objtype, name, obj, False, {} - ) - except Exception as exc: - ag.logger.warning( - __( - 'autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s' - ), - name, - exc, - type='autosummary', - ) - return False - - def get_class_members(obj: Any) -> Dict[str, Any]: - members = sphinx.ext.autodoc.get_class_members( - obj, [qualname], ag.safe_getattr - ) - return {name: member.object for name, member in members.items()} - - def get_module_members(obj: Any) -> Dict[str, Any]: - members = {} - for name in ag.members_of(obj, app.config): - try: - members[name] = ag.safe_getattr(obj, name) - except AttributeError: - continue - return members - - def get_all_members(obj: Any) -> Dict[str, Any]: - if doc.objtype == 'module': - return get_module_members(obj) - elif doc.objtype == 'class': - return get_class_members(obj) - return {} - - def get_members( - obj: Any, - types: Set[str], - include_public: List[str] = [], - imported: bool = True, - ) -> Tuple[List[str], List[str]]: - items: List[str] = [] - public: List[str] = [] - - all_members = get_all_members(obj) - for name, value in all_members.items(): - documenter = ag.get_documenter(app, value, obj) - if documenter.objtype in types: - # skip imported members if expected - if imported or getattr(value, '__module__', None) == obj.__name__: - skipped = skip_member(value, name, documenter.objtype) - if skipped is True: - pass - elif skipped is False: - # show the member forcedly - items.append(name) - public.append(name) - else: - items.append(name) - if name in include_public or not name.startswith('_'): - # considers member as public - public.append(name) - return public, items - - def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: - """Find module attributes with docstrings.""" - attrs, public = [], [] - try: - analyzer = ag.ModuleAnalyzer.for_module(name) - attr_docs = analyzer.find_attr_docs() - for namespace, attr_name in attr_docs: - if namespace == '' and attr_name in members: - attrs.append(attr_name) - if not attr_name.startswith('_'): - public.append(attr_name) - except ag.PycodeError: - pass # give up if ModuleAnalyzer fails to parse code - return public, attrs - - def get_modules(obj: Any) -> Tuple[List[str], List[str]]: - items: List[str] = [] - for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): - fullname = name + '.' + modname - try: - module = ag.import_module(fullname) - if module and hasattr(module, '__sphinx_mock__'): - continue - except ImportError: - pass - - items.append(fullname) - public = [x for x in items if not x.split('.')[-1].startswith('_')] - return public, items - - ns: Dict[str, Any] = {} - ns.update(context) - - if doc.objtype == 'module': - scanner = ag.ModuleScanner(app, obj) - ns['members'] = scanner.scan(imported_members) - ns['functions'], ns['all_functions'] = get_members( - obj, {'function'}, imported=imported_members - ) - ns['classes'], ns['all_classes'] = get_members( - obj, {'class'}, imported=imported_members - ) - ns['exceptions'], ns['all_exceptions'] = get_members( - obj, {'exception'}, imported=imported_members - ) - ns['attributes'], ns['all_attributes'] = get_module_attrs(ns['members']) - ispackage = hasattr(obj, '__path__') - if ispackage and recursive: - ns['modules'], ns['all_modules'] = get_modules(obj) - elif doc.objtype == 'class': - ns['members'] = dir(obj) - ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys()) - ns['methods'], ns['all_methods'] = get_members( - obj, {'method'}, ['__init__'] - ) - ns['attributes'], ns['all_attributes'] = get_members( - obj, {'attribute', 'property'} - ) - ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys()) - - if modname is None or qualname is None: - modname, qualname = ag.split_full_qualified_name(name) - - if doc.objtype in ('method', 'attribute', 'property'): - ns['class'] = qualname.rsplit('.', 1)[0] - - if doc.objtype in ('class',): - shortname = qualname - else: - shortname = qualname.rsplit('.', 1)[-1] - - ns['fullname'] = name - ns['module'] = modname - ns['objname'] = qualname - ns['name'] = shortname - - ns['objtype'] = doc.objtype - ns['underline'] = len(name) * '=' - - if template_name: - return template.render(template_name, ns) - else: - return template.render(doc.objtype, ns) - - -ag.generate_autosummary_content = generate_autosummary_content diff --git a/t5x-main/docs/contributions.md b/t5x-main/docs/contributions.md deleted file mode 100644 index eb6f1e77f2c9f3de520a0138795e572568890f1f..0000000000000000000000000000000000000000 --- a/t5x-main/docs/contributions.md +++ /dev/null @@ -1,64 +0,0 @@ -# Contributions - -T5X was developed as part of the T5 Infrastructure effort at Google Research. - -Adam Roberts founded and leads the project, designed and wrote much of `seqio` -and `t5x`, and co-authored the -[T5X and SeqIO paper](https://arxiv.org/abs/2203.17189). Hyung Won Chung -designed and wrote much of `t5x`, led its open sourcing, and co-authored the -paper. Anselm Levskaya built the initial prototype for `t5x` and wrote much of -the code. Gaurav Mishra leads `seqio`, implemented deterministic pipelines, and -co-authored the paper. James Bradbury implemented partitioning in `t5x` and -co-wrote the paper. - -Daniel Andor, Sharan Narang, Brian Lester, Colin Gaffney, Afroz Mohiuddin, -Curtis Hawthorne, Aitor Lewkowycz, Alex Salcianu, Marc van Zee, Jacob Austin, -Sebastian Good-man, Livio Baldini Soares, Haitang Hu, Sasha Tsvyashchenko, -Aakanksha Chowdhery, Jasmijn Bastings, Jannis Bulian, Xavier Garcia, Jianmo Ni, -Andrew Chen, Kathleen Kenealy, Kehang Han, Jonathan H. Clark, Stephan Lee, Dan -Garrette, and James Lee-Thorp made substantial code contributions. - -Colin Raffel and Noam Shazeer helped design `seqio`. Marvin Ritter advised on -deterministic pipelines and the use of CLU Metrics. Maarten Bosma helped design -deterministic pipelines. Jeremy Maitin-Shepard advised on the use of -TensorStore. Alexandre Passos and Ryan Sepassi advised on overall technical -design. - -Noah Fiedel is a member of the leadership team, contributed to the high level -design and roadmap, and co-wrote the paper. Mark Omernick, Brennan Saeta, Ryan -Sepassi, Alexander Spiridonov (Product Manager), and Josh Newlan (Technical -Program Manager) are members of the leadership team and co-wrote the paper. -Andrea Gesmundo is a member of the leadership team and contributed to the -internal infrastructure component. - -Thanks to the many other contributors to the project: Ian Simon, Reiner Pope, -Vincent Zhao, Pierre Ruyssen, Linting Xue, Junwhan Ahn, Barret Zoph, David -Dohan, Masumi Parekh, Chang Lan, Frederick Liu, Julien Amelot, Luheng He, Fede -Lebron, RebeccaChen, Anosh Raj, Mandy Guo, Ethan Dyer, Mihai Tiuca, Hongkun Yu, -Kevin Brooks, David Soergel, Kelvin Guu, Joshua Ainslie, Luyao Xu, Ji Ma, Josh -Gardner, Daphne Ippolito, Peter Hawkins, Bo Pang, Marc Rasi, Wei Li, Wenhu Chen, -Iulia Turc, John Wieting, Alex Passos, Zonglin Li, Katie Everett, Olivier -Bachem, Francesco Piccinno, Jakub Adamek, Jonathan Heek, Parker Schuh, Hexiang -Hu, Du Phan, Max Moroz, David Miller, Ryan Doherty, David Elworthy, Alfonso -Casta ̃no, Julian Eisenschlos, Vlad-Doru Ion, Lucas Dixon, Ron Shapiro, Dinghua -Li, Aaron Parisi, Xi Chen, Nan Ding, Chung-ching Chang, Timothy Dozat, Natalia -Ponomareva, Delesley Hutchins, Ankush Garg, Yu-Han Liu, Mehrdad Khatir, Costanza -Conforti, Philipp Keck, Rapha ̈el Marinier, Marie Pellat, Raghuram Vadapalli, -Joshua Maynez, Yi Tay, Xihui Wu, David Belanger, Luke Metz, Dan Zheng, Deepti -Bhatia, Hariharan Shanmugavadivel, Rewon Child, Rigel Swavely, Mihir Sanjay -Kale, Arash Afkanpour, Roberto Rama, Juro Gottweis, Jonathan Herzig, Yilei Yang, -Elias Mizan, Pedram Pejman, Jiayu Ye, Smit Sanghavi, Rahul Joshi, Ziqiang Feng, -Charles Sutton, Weikang Zhou, Liam Fedus, Shanqing Cai, Ginger Perng, Yash -Katariya, Urvashi Khandelwal, Sebastian Gehrmann, Edward Loper, Tianze Shi, Luke -Vilnis, Amelia Archer, Tom Weingarten, David Zats, Murtaza Dhuliawala, Xin Xie, -Sahil Dua, Andr ́e SusanoPinto, Piotr Padlewski, Sascha Rothe, Erik Aas, Felix -Stahlberg, Ken Durden, Christina Sorokin, Jaehoon Lee, Roy Frostig, Jacob -Devlin, Jorge Gonzalez Mendez, Deepak Ramachandran, Santiago Ontanon, Karthik -Raman, Yi Sun, Ali Elqursh, Reuben La Haye,Adam Fahrenkopf, Alex Polozov, Vinay -Ramasesh, Ian Tenney. - -Thanks to NVIDIA for GPU contributions: Sahil Jain, Terry Kong, Yu-Hang Tang, -Ming Huang, Frederic Bastien, Sharath Turuvekere Sreenivas, Xiaowei Ren, Ryan Jeng, - Reese Wang - -Thanks to Douglas Eck and Zoubin Ghahramani for sponsoring the project. diff --git a/t5x-main/docs/index.md b/t5x-main/docs/index.md deleted file mode 100644 index 45317f0b39466a0b149dbaa9eec87aa90404d074..0000000000000000000000000000000000000000 --- a/t5x-main/docs/index.md +++ /dev/null @@ -1,65 +0,0 @@ -# T5X - - -Note: T5X is community-supported since ~2023. For critical use cases, consider -using libraries like TuneLab (go/tunelab) and Gemax Prod (go/gemax-prod). See -https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-to-gemax-prod for useful tips on transitioning. - -## Overview - -T5X is a modular, composable, research-friendly framework for high-performance, -configurable, self-service training, evaluation, and inference of sequence -models (starting with language) at many scales. - -It is essentially a new and improved implementation of the -[T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md) (based on Mesh TensorFlow) in JAX and Flax. To learn -more, see the [T5X Paper](https://arxiv.org/abs/2203.17189). - -## Getting Started - -Here are some quick tutorials to help you get started with common use-cases on -T5X: - -#### [Introductory Colabs](tutorials.md) - -If you are new to T5X, we recommend starting with our introductory Colab series, -which introduces core concepts of both T5X and SeqIO. More colabs will be added -to this series regularly! - -#### [Fine-tuning a model](usage/finetune.md) - -This tutorial outlines the steps to fine-tune an existing pre-trained model with -T5X on common downstream Tasks/Mixtures available on SeqIO. This is one of the -simplest and most common use cases of T5X. If you're new to T5X, this tutorial -is the recommended starting point. - -#### [Running evaluation on a model](usage/eval.md) - -This tutorial outlines the steps to evaluate a model with T5X on downstream -Tasks/Mixtures defined in SeqIO. - -#### [Running inference on a model](usage/infer.md) - -This tutorial outlines the steps to run inference on a model with T5X. - -#### [Training a model from scratch](usage/pretrain.md) - -This tutorial outlines the steps to pretrain a model with T5X on Tasks/Mixtures -defined in SeqIO. - -#### [Gin Primer](usage/gin.md) - -This tutorial provides a quick introduction to Gin, a lightweight configuration -framework for Python that is used to configure training, eval and inference jobs -on T5X. - -#### [Partitioning Primer](usage/partitioning.md) - -This tutorial provides background on what model and data partitioning are and -how it can be configured in T5X. - -#### [Metrics Overview](usage/metrics.md) - -This tutorial provides an overview of how metrics can be used and customized to -evaluate T5X models. - diff --git a/t5x-main/docs/index.rst b/t5x-main/docs/index.rst deleted file mode 100644 index 360a30051e8b60ce22eb9a30fb094ab28eab6396..0000000000000000000000000000000000000000 --- a/t5x-main/docs/index.rst +++ /dev/null @@ -1,24 +0,0 @@ -****************************** -T5X -****************************** - - -T5X is a modular, composable, research-friendly framework for high-performance, -configurable, self-service training, evaluation, and inference of sequence -models (starting with language) at many scales. - -It is essentially a new and improved implementation of the -`T5 codebase `__ -(based on Mesh TensorFlow) in JAX and Flax. To learn more, see the -`T5X Paper `__. - -.. toctree:: - :maxdepth: 2 - :caption: Table of Contents - - Quick Start - Tutorials - Usage Guides - Models - api_reference/index - contributions diff --git a/t5x-main/docs/models.md b/t5x-main/docs/models.md deleted file mode 100644 index 26988ea41de4cbfcfbb21fe8458d0dc0e0555928..0000000000000000000000000000000000000000 --- a/t5x-main/docs/models.md +++ /dev/null @@ -1,318 +0,0 @@ -# Models - - -This page lists the available pre-trained T5 models. To use a pre-trained model, -you need a Gin config file that defines the model params, and the model -checkpoint to load from. For your convenience, TensorFlow checkpoints and Gin -configs for common T5 pre-trained models have been made available for use in -T5X. Following is a list of these pre-trained models and their Gin and -checkpoint locations. - -+ All checkpoints: - [`gs://t5-data/pretrained_models/t5x/`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/) -+ All Gin files: - [`t5x/configs/models/`](https://github.com/google-research/t5x/blob/main/t5x/configs/) - -### Selecting a model: - -Publicly Available Models: - -Model | Use Case ----------------------------------------------------- | -------- -[T5 1.1](#t5-11-checkpoints) | Improved T5, recommended for most research. English only. -[T5](#t5-checkpoints) | The original T5 work for reproducibility. English only. -[T5 1.1 LM-Adapted](#t5-11-lm-adapted-checkpoints) | Trained for 100k additional steps on the LM objective, per [prompt tuning paper](https://arxiv.org/abs/2104.08691). -[mT5](#mt5-checkpoints) | Multilingual T5. Recommended for multilingual research. Note that at smaller scales (at least through XL), mT5 performance is lower than T5 on English tasks. -[mT5 LM-Adapted](#mt5-lm-adapted-checkpoints) | Trained for 100k additional steps on the LM objective, per [zero-shot cross-lingual generation (XGen) paper](https://arxiv.org/abs/2205.12647). -[umT5](#umt5-checkpoints) | umT5, an updated mT5 model trained using a more uniform language distribution, per [the UniMax paper](https://openreview.net/forum?id=kXwdL1cWOAi). -[ByT5](#byt5-checkpoints) | ByT5. A "token-free" model that uses UTF-8 bytes for input and output. Recommended for tasks involving word-internal phenomena such as spelling, pronunciation, or morphology. -[LongT5](#longt5-checkpoints) | Recommended checkpoints to fine-tune for long input sequence tasks -[MoE](#mixture-of-experts-moe-checkpoints) | Useful for MoE experimentation. -[Flan-T5](#flan-t5-checkpoints) | General purpose T5 checkpoints for few-shot and finetuning. We recommend Flan-T5 over vanilla T5 and T5 LM-adapted -[UL2](#ul2-checkpoints) | Checkpoints for 20B pretrained and FLAN-based instruction-tuned models using the UL2 objective from [UL2 paper](https://arxiv.org/abs/2205.05131) -[BigScience](#bigscience-checkpoints) | Checkpoints from the [BigScience paper](https://arxiv.org/abs/2204.05832) -[FLIP](#flip-checkpoints) | Language-Image models trained with an alternative to CLIP, presented in the [FLIP paper](https://arxiv.org/abs/2212.00794) -[RankGen](#rankgen-checkpoints) | 1.2B parameter encoder model for English to score model generations given a prefix for decoding from the [RankGen paper](https://arxiv.org/abs/2205.09726) -[Dipper](#dipper-checkpoints) | 11B parameter paraphrase generation model from the [Dipper paper](https://arxiv.org/abs/2303.13408) - - -### Public Research Models - -#### T5 Checkpoints - -These are the checkpoints used in the paper [Exploring the Limits of Transfer -Learning with a Unified Text-to-Text -Transformer](https://arxiv.org/abs/1910.10683). They are encoder-decoder models -pre-trained on [C4](https://www.tensorflow.org/datasets/catalog/c4) with a "span -corruption" denoising objective, in addition to a mixture of downstream tasks -including: GLUE, SuperGLUE, CNN/Daily Mail, SQuAD, and WMT. - -**Vocabulary:** -[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra) - -Model | Gin File Location | Checkpoint Location --------- | ------------------------------------------------------------------------------ | ------------------- -T5 Small | [t5_small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/small.gin) | [gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_small) -T5 Base | [t5_base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/base.gin) | [gs://t5-data/pretrained_models/t5x/t5_base/checkpoint_999900](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_base) -T5 Large | [t5_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/large.gin) | [gs://t5-data/pretrained_models/t5x/t5_large/checkpoint_1000700](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_large) -T5 3B | [t5_3B.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/3B.gin) | [gs://t5-data/pretrained_models/t5x/t5_3B/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_3B) -T5 11B | [t5_11B.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_0/11B.gin) | [gs://t5-data/pretrained_models/t5x/t5_11B/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_11B) - -#### T5 1.1 Checkpoints - -These are similar to the models from [Exploring the Limits of Transfer Learning -with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683), but -with the following improvements: - -* GEGLU activation in feed-forward hidden layer, rather than ReLU - see - https://arxiv.org/abs/2002.05202 . -* Dropout was turned off in pre-training (quality win). Dropout should be - re-enabled during fine-tuning. -* Pre-trained on C4 only without mixing in the downstream tasks. -* no parameter sharing between embedding and classifier layer -* "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit - different - larger d_model and smaller num_heads and d_ff. - -For English-language, sequence-to-sequence-style tasks (ones where the goal is -to map from an input text sequence to a target sequence) these are usually the -best models to fine-tune. - -**Vocabulary:** -[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra) - -Model | Gin File Location | Checkpoint Location ------------- | ---------------------------------------------------------------------------------- | ------------------- -T5 1.1 Small | [t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_small) -T5 1.1 Base | [t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_base) -T5 1.1 Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_large) -T5 1.1 XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xl.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_xl) -T5 1.1 XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xxl.gin) | [gs://t5-data/pretrained_models/t5x/t5_1_1_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_xxl) - -#### T5 1.1 LM-Adapted Checkpoints - -These "LM-adapted" models are initialized from T5 1.1 (above) and trained for an -additional 100K steps on the LM objective discussed in the -[T5 paper](https://arxiv.org/abs/1910.10683). This adaptation improves the -ability of the model to be used for -[prompt tuning](https://arxiv.org/abs/2104.08691). These checkpoints were also -used within the BigScience [T0](https://arxiv.org/abs/2110.08207) project. - -**Vocabulary:** -[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra) - -Model | Gin File Location | Checkpoint Location --------------------- | ------------------------------------------------------------------------------------------------------------------- | ------------------- -T5 1.1 LM-100K Small | [t5_1_1_small.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin) | [t5_1_1_lm100k_small/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_small) -T5 1.1 LM-100K Base | [t5_1_1_base.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_base.gin) | [t5_1_1_lm100k_base/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_base) -T5 1.1 LM-100K Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_large.gin) | [t5_1_1_lm100k_large/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_large) -T5 1.1 LM-100K XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_xl.gin) | [t5_1_1_lm100k_xl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl) -T5 1.1 LM-100K XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_xxl.gin) | [t5_1_1_lm100k_xxl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl) - - -#### mT5 Checkpoints - -These are the checkpoints used in the paper -[mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer](https://aclanthology.org/2021.naacl-main.41/). -They are encoder-decoder models trained on -[multilingual C4](https://www.tensorflow.org/datasets/catalog/c4#c4multilingual) -with a denoising objective. These are the best checkpoints to fine-tune for -non-English sequence-to-sequence tasks. - -**Vocabulary:** -[mc4.250000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/mc4.250000.100extra) - -Model | Gin File Location | Checkpoint Location ---------- | ---------------------------------------------------------------------------- | ------------------- -mT5 Small | [mt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/small.gin) | [gs://t5-data/pretrained_models/t5x/mt5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_small) -mT5 Base | [mt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/base.gin) | [gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_base) -mT5 Large | [mt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/large.gin) | [gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_large) -mT5 XL | [mt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xl.gin) | [gs://t5-data/pretrained_models/t5x/mt5_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_xl) -mT5 XXL | [mt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xxl.gin) | [gs://t5-data/pretrained_models/t5x/mt5_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_xxl) - -#### mT5 LM-Adapted Checkpoints - -These are the checkpoints released as part of the -[zero-shot cross-lingual generation (XGen) paper](https://arxiv.org/abs/2205.12647). - -These "LM-adapted" models are initialized from mT5 (above) and trained for an -additional 100K steps on the LM objective discussed in the -[T5 paper](https://arxiv.org/abs/1910.10683). - -This adaptation improves the ability of the model to be used for -[prompt tuning](https://arxiv.org/abs/2104.08691). - -**Vocabulary:** -[mc4.250000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/mc4.250000.100extra) - -Model | Gin File Location | Checkpoint Location --------------------- | ---------------------------------------------------------------------------- | ------------------- -mT5 LM-Adapted Small | [mt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/small.gin) | [mt5_lm_adapted/small/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/small/checkpoint_1100000) -mT5 LM-Adapted Base | [mt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/base.gin) | [mt5_lm_adapted/base/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/base/checkpoint_1100000) -mT5 LM-Adapted Large | [mt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/large.gin) | [mt5_lm_adapted/large/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/large/checkpoint_1100000) -mT5 LM-Adapted XL | [mt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xl.gin) | [mt5_lm_adapted/xl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/xl/checkpoint_1100000) -mT5 LM-Adapted XXL | [mt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/mt5/xxl.gin) | [mt5_lm_adapted/xxl/checkpoint_1100000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/mt5_lm_adapted/xxl/checkpoint_1100000) - -#### umT5 Checkpoints - -These are the checkpoints described in the paper [UniMax: Fairer and More -Effective Language Sampling for Large-Scale Multilingual -Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi). umT5 is similar to -mT5 (see above); both are multilingual encoder-decoder models ranging from 300M -to 13B parameters, trained on the mC4 corpus using a denoising objective. umT5 -is trained on a fresher version of the mC4 corpus (3.1.0), and with a more -uniform language balancing strategy. - -**Vocabulary:** [umt5.256000](https://console.cloud.google.com/storage/browser/t5-data/vocabs/umt5.256000) - -Model | Gin File Location | Checkpoint Location ----------- | --------------------------------------------------------------------------------------------------------- | ------------------- -umT5 Small | [umt5/pretrain_small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_small.gin) | [umt5/small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/small/checkpoint_1000000) -umT5 Base | [umt5/pretrain_base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_base.gin) | [umt5/base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/base/checkpoint_1000000) -umT5 XL | [umt5/pretrain_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_xl.gin) | [umt5/xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/xl/checkpoint_1000000) -umT5 XXL | [umt5/pretrain_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/scalable_t5/umt5/pretrain_xxl.gin) | [umt5/xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/umt5/xxl/checkpoint_1000000) - -#### ByT5 Checkpoints - -These are the checkpoints used in the paper -[ByT5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models](https://aclanthology.org/2022.tacl-1.17/). -They are similar to mT5 (above), but are "token-free", processing text as raw -UTF-8 bytes, as opposed to using a pretrained subword vocabulary. These models -are more robust to character-level noise, and outperform parameter-matched mT5 -models in many settings, particularly on word-level tasks sensitive to spelling, -pronunciation, or morphology. However inference is significantly slower, up to -10x depending on the task. - -**Vocabulary:** None - -Model | Gin File Location | Checkpoint Location ----------- | ------------------------------------------------------------------------------ | ------------------- -ByT5 Small | [byt5/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/small.gin) | [gs://t5-data/pretrained_models/t5x/byt5_small/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_small) -ByT5 Base | [byt5/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/base.gin) | [gs://t5-data/pretrained_models/t5x/byt5_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_base) -ByT5 Large | [byt5/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/large.gin) | [gs://t5-data/pretrained_models/t5x/byt5_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_large) -ByT5 XL | [byt5/xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/xl.gin) | [gs://t5-data/pretrained_models/t5x/byt5_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_xl) -ByT5 XXL | [byt5/xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/byt5/xxl.gin) | [gs://t5-data/pretrained_models/t5x/byt5_xxl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/byt5_xxl) - -#### LongT5 Checkpoints - -These are the checkpoints used in the paper -[LongT5: Efficient Text-to-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916). -They are encoder-decoder models trained on -[C4](https://www.tensorflow.org/datasets/catalog/c4) using the PEGASUS Principle -Sentences Generation objective. These are the recommended checkpoints to -fine-tune for long input sequence tasks. - -##### LongT5 Local Attention Checkpoints - -The checkpoints below use local attention, which uses a sliding window to reduce -training time from quadratic (with regards to input length) to linear. These are -the recommended checkpoints to use for faster training/inference time. - -**Vocabulary:** -[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra) - -Model | Gin File Location | Checkpoint Location ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | ------------------- -LongT5 Local Attention Base | [longt5/models/longt5_1_1_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_base.gin) | [gs://t5-data/pretrained_models/t5x/longt5/local_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/local_base) -LongT5 Local Attention Large | [longt5/models/longt5_1_1_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_large.gin) | [gs://t5-data/pretrained_models/t5x/longt5/local_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/local_large) - -##### LongT5 Transient Global Attention Checkpoints - -The checkpoints below use transient global attention, which introduces global -tokens at each encoder layer to allow tokens to interact with each other at -longer distances. These are the recommended checkpoints to use for increased -performance on long input sequence tasks. - -**Vocabulary:** -[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra) - -Model | Gin File Location | Checkpoint Location ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------- -LongT5 Base | [longt5/models/longt5_1_1_transient_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_base.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_base/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_base) -LongT5 Large | [longt5/models/longt5_1_1_transient_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_large.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_large/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_large) -LongT5 XL | [longt5/models/longt5_1_1_transient_xl.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_xl.gin) | [gs://t5-data/pretrained_models/t5x/longt5/tglobal_xl/checkpoint_1000000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/longt5/tglobal_xl) - -#### Mixture of Experts (MoE) Checkpoints - -These MoE checkpoints need to be used with T5X MoE overrides -- specifically, -the MoeTrainer and the MoePjitPartitioner. For example, for fine-tuning, use the -[MoE fine-tune run config](https://github.com/google-research/t5x/blob/main/t5x/contrib/moe/configs/runs/finetune.gin). - - -##### Converted Mesh Tensorflow checkpoints - -[Switch Transformer model](https://arxiv.org/abs/2101.03961). - -**Vocabulary:** -[cc_all.32000.100extra](https://console.cloud.google.com/storage/browser/t5-data/vocabs/cc_all.32000.100extra) - - -Model | Gin File Location | Checkpoint Location ----------------------------------------- | ------------------------------------------------------------------------------------------------------------ | ------------------- -Switch Transformer Base 8 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e8/checkpoint_500100](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e8) -Switch Transformer Base 16 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e16/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e16) -Switch Transformer Base 32 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e32/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e32) -Switch Transformer Base 64 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e64/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e64) -Switch Transformer Base 128 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e128/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e128) -Switch Transformer Base 256 Experts | [switch_base.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e256/checkpoint_550000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/base/e256) -Switch Transformer Large 128 Experts | [switch_large.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_large.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/large/e128/checkpoint_483100](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/large/e128) -Switch Transformer XXL 128 Experts | [switch_xxl.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_xxl.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/xxl/e128/checkpoint_634600](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/xxl/e128) -Switch Transformer C 2048 Experts (1.6T) | [switch_c.gin](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_c.gin) | [gs://t5-data/pretrained_models/t5x/moe/switch_classic/c/e2048/checkpoint_611800](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/moe/switch_classic/c/e2048) - - -#### Flan-T5 Checkpoints - -These are the checkpoints released as part of the paper -[Scaling Instruction-Finetuned Language Models](https://arxiv.org/abs/2210.11416). -They were initialized from the -[T5 1.1 LM-Adapted](#t5-11-lm-adapted-checkpoints) and instruction-finetuned. - -They significantly outperform the LM-adapted checkpoints. For example, -Flan-T5-XXL outperforms T5-LM-XXL by 26.6% absolute on the normalized average -score. It even outperforms a much larger PaLM 62B model on -[BigBench Hard](https://arxiv.org/abs/2210.09261) a set of challenging BigBench -benchmark. - -Unlike the vanilla T5 checkpoints, these can be directly used for few-shot -prompting as well as standard finetuning. See -[Chung et al. 2022](https://arxiv.org/abs/2210.11416) for details. - -Model | Gin File Location | Checkpoint Location -------------- | ---------------------------------------------------------------------------------- | ------------------- -Flan-T5 Small | [t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_small/checkpoint_1198000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_small/checkpoint_1198000) -Flan-T5 Base | [t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_base/checkpoint_1184000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_base/checkpoint_1184000) -Flan-T5 Large | [t5_1_1_large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_large/checkpoint_1164000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_large/checkpoint_1164000) -Flan-T5 XL | [t5_1_1_xl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xl/checkpoint_1138000) -Flan-T5 XXL | [t5_1_1_xxl.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/xxl.gin) | [gs://t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/flan_t5_xxl/checkpoint_1114000) - -#### UL2 Checkpoints - -Checkpoints for 20B pretrained and FLAN-based instruction-tuned models using the -UL2 objective from [UL2 paper](https://arxiv.org/abs/2205.05131). Checkpoints -are released at -https://github.com/google-research/google-research/tree/master/ul2#checkpoints. - -#### BigScience Checkpoints - -Checkpoints from the [BigScience paper](https://arxiv.org/abs/2204.05832), -released at -https://github.com/bigscience-workshop/architecture-objective/tree/main#checkpoints. - -#### FLIP Checkpoints - -Language-Image models trained with an alternative to CLIP, presented in the -[FLIP paper](https://arxiv.org/abs/2212.00794). Checkpoints are released at -https://github.com/facebookresearch/flip#results-and-pre-trained-flip-models. - -#### RankGen Checkpoints - -1.2B parameter encoder model for English to score model generations given a -prefix for decoding from the [RankGen paper](https://arxiv.org/abs/2205.09726). -Checkpoints are released at -https://github.com/google-research/google-research/tree/master/rankgen. - -#### Dipper Checkpoints - -11B parameter paraphrase generation model from the -[Dipper paper](https://arxiv.org/abs/2303.13408). Checkpoints are released at -https://github.com/google-research/google-research/tree/master/dipper. - diff --git a/t5x-main/docs/overview.md b/t5x-main/docs/overview.md deleted file mode 100644 index d1393125a30d5436bda7d388cc3baa1d4555c2a4..0000000000000000000000000000000000000000 --- a/t5x-main/docs/overview.md +++ /dev/null @@ -1,2 +0,0 @@ -```{include} ../README.md -``` \ No newline at end of file diff --git a/t5x-main/docs/requirements.txt b/t5x-main/docs/requirements.txt deleted file mode 100644 index df2edd4ac0d1ba15f653c27890ef3d635c2a5f47..0000000000000000000000000000000000000000 --- a/t5x-main/docs/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -sphinx>=4.4.0 -myst_parser>=0.16.1 -myst_nb -sphinx-design -sphinx-book-theme - -# Must install t5x itself for notebook execution and autodocs to work. -. \ No newline at end of file diff --git a/t5x-main/docs/t5x.png b/t5x-main/docs/t5x.png deleted file mode 100644 index 6430d6eea4fa0fcd687ce9ac297746d0764cb1f0..0000000000000000000000000000000000000000 --- a/t5x-main/docs/t5x.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5e903d6a7cb99b192a23b895cd30157d5661cd0e895b3f1d6f2027fdfb1b66dd -size 1835901 diff --git a/t5x-main/docs/tutorials.md b/t5x-main/docs/tutorials.md deleted file mode 100644 index 30f15fc2e48801d27e14bc48b84cb5d5b6451983..0000000000000000000000000000000000000000 --- a/t5x-main/docs/tutorials.md +++ /dev/null @@ -1,51 +0,0 @@ -# T5X Introductory Tutorial Series - - -## Overview - -This series of guides is a self-contained introduction to T5X, a modular, -composable, research-friendly framework for high-performance, configurable, -self-service training, evaluation, and inference of sequence models (starting -with language) at many scales. - - -## How to Use These Guides - -Most entries in this series are colab notebooks (click the blue banners to the -right of each heading below), allowing you to run our tutorial code -interactively. We encourage you to do that! Play around, change things, see what -happens! - - -## T5X Guides - -### Codelab 1: An Introduction to T5X - -Open in colab
- -In this colab, you will learn about some of the basic T5X components and put -them to use to run training, inference, and evaluation on natural text inputs. - -### Codelab 2: Training Deep Dive - -Open in colab
- -In this colab, you will dive into how to restore T5X models from checkpoints and -run training, while also getting an introduction to the T5X trainer. - -### Codelab 3: Inference Deep Dive - -Open in colab
- -In this colab, you will dive into how the Interactive Model does decoding to -generate predictions and scores for a given input. - -### Codelab 4: Evaluation Deep Dive - -Open in colab
- -In this colab, you will dive into how the InteractiveModel takes a batch of -inputs and targets and runs evaluation to produce various metrics. - - -### More Colabs coming soon! diff --git a/t5x-main/docs/usage/auxiliary.md b/t5x-main/docs/usage/auxiliary.md deleted file mode 100644 index 9bba8ec30448a950deb3b8390aed25e55189ec91..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/auxiliary.md +++ /dev/null @@ -1,204 +0,0 @@ -# Auxiliary Job - - -## Introduction - -This page outlines the steps needed to use the auxiliary job capabilities -available in T5X. - -## Overview - -There are a variety of situations in which running a single job is insufficient -or suboptimal. For example, consider the following scenarios: - -+ You want to keep track of evaluation (`infer_eval` or `train_eval`) metrics - per checkpoint, but evaluation takes a very long time due to having a large - eval dataset, slow decoding, or multiple tasks to evaluate. - -+ You want to finetune every checkpoint on a downstream task as you train. - -+ You have customized evaluation code that you want to run on every checkpoint - as you train, but that does not naturally fit within a `seqio.Evaluator` - framework. - -In cases like these, users can make use of the auxiliary job functionality. At a -high-level, the auxiliary job will launch a new job every time a new checkpoint -is saved. This new job can either re-use the `train.py` binary (e.g. for -continuous finetuning) or a different one. For example, this allows users to -perform continuous evaluation (using `eval.py`) without slowing down the -training job. We will provide detailed examples showing how to use the auxiliary -job for these use-cases. - -When this new job is launched, the controller will replace four gin macros: -`MODEL_DIR`, `MIXTURE_OR_TASK_NAME`,`INITIAL_CHECKPOINT_PATH`, `TRAIN_STEPS`. -The second of these is set by the user-controlled flag (more on this below), and -the third one is equal to the last checkpoint seen. Aside from this, users are -free to modify the configuration as needed. Beyond gin macros, the auxiliary job -can also have different resource requirements, priority, and even cell placement -from the train job. - -## Example 1: Separate evaluation job. - -### Step 1: Choose a model architecture. - -Similar to pretraining, we will need some gin configuration. For this example, -we will use the T5-1.1-Base model. - -### Step 2: Choose a SeqIO Task/Mixture for training and evaluation. - -In this example, we will use the classic task of English-French translation from -WMT14, which is conveniently available as a SeqIO task in the tasks file from -the T5 tasks under the name `'wmt_enfr14_v003'`. - -### Step 3: Write a Gin config. - -Unlike pretraining or finetuning, we will need two gin files for this setup: one -for the training job, and one for the auxiliary job. The train gin file will -have the same requirements as the gin file for pretraining or finetuning. The -auxiliary job gin file can leverage these gin files or be its own independent -gin file, depending on the user’s choice. For this example, we will make a new -gin which is mostly a wrapper around `pretrain.gin` with some additional -hardcoded features. We will use this gin file for the train job and `eval.gin` -for the auxiliary job. - -### Step 4: Launch your experiment. - -Our sample script will be quite similar to the one used in pretraining and -finetuning, but with a few additional flags which we describe below. - -+ `auxiliary_job_mixtures`: This is a comma-separated list of mixtures. A - separate auxiliary job will be run for each mixture and will replace the gin - macro `MIXTURE_OR_TASK_NAME`. Note that you need this flag even if you are - using a custom binary, which does not need a mixture since otherwise no - auxiliary job will run. - -+ `auxiliary_job_gin_file`: This is identical to `gin_file`, except it is used - for the auxiliary job instead of the train job. - -+ `replace_gin_file`: If True, this auxiliary launcher will not use any of the - gin files from train job. This is necessary when using a binary different - from `train.py`, since the top-level functions will not match. - -+ `auxiliary_job_cell`: The cell in which to run your job. Note that this can - be different from the training cell. - -+ `auxiliary_job_platform`: The platform to use for the auxiliary. Note that - this can be different from the one use for the train job, allowing users to - use smaller configurations for evaluation than needed for training. - -+ `auxiliary_job_build_target`: The binary to use for auxiliary job. - -+ `final_auxiliary_job_steps`: This flag controls how many additional steps to - take when using the auxiliary job for finetuning. Setting to 0 enables - continuous evaluation. - -We provide the sample script below. - -```sh -declare -a ARGS=( ---cell=iz ---platform=jd=2x2 ---final_auxiliary_job_steps=0 ---replace_gin_file=True ---auxiliary_job_mixtures=wmt14_enfr_v003 ---auxiliary_job_gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin ---auxiliary_job_cell=iz ---auxiliary_job_platform=jd=2x2 ---auxiliary_job_build_target_path=//t5x:eval ---gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin -) - -gxm t5x/google/xm_launch.py "${ARGS[@]}" -``` - -## Example 2: Continuous finetuning job. - -In this example, we will be pretraining a model on a span corruption task on the -C4 dataset, and finetuning it on the WMT'14 English-French translation task. As -before, we will launch a new auxiliary job once every checkpoint is saved. -However, instead of using the `eval.py` binary, we will use the `train.py` -binary. - -### Step 1: Choose a model architecture. - -We will use the T5-1.1-Base model as in the previous example. - -### Step 2: Choose a SeqIO Task/Mixture for training and evaluation. - -For pretraining, we re-use the span coprruption task `c4_v220_span_corruption` -available in the T5 mixtures `tasks.py` file. - -### Step 3: Write a Gin config. - -As before, we need our gin files to contain all the desired macros in them. We -thus create two new gin files: `base_c4_pretrain.gin` for the train job and -`base_wmtenfr14_finetune.gin` for the auxiliary job. - -### Step 4: Launch your experiment. - -Our script is quite similar to the first example, with the same flags as before -but with the appropiate changes. The main distinction is that we must change the -flag `final_auxiliary_job_steps` to be non-zero to start finetuning. We will -settle for a modest 200 steps for the sake of demonstration (and evaluate every -100 steps), but users should use larger steps in realistic scenarios. We also -use `train.py` binary instead of `eval.py`. - -We provide the sample script below. - -```sh -declare -a ARGS=( ---cell=iz ---platform=jd=2x2 ---final_auxiliary_job_steps=200 ---replace_gin_file=True ---auxiliary_job_mixtures=wmt14_enfr_v003 ---auxiliary_job_gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin ---auxiliary_job_cell=iz ---auxiliary_job_platform=jd=2x2 ---auxiliary_job_build_target_path=//t5x:train ---gin_file=t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin -) - -gxm t5x/google/xm_launch.py "${ARGS[@]}" -``` - -## Common Gotchas. - -We outline a few common error patterns that we have encountered. - -+ **Not passing a value for the `auxiliary_mixtures` flag.** Even if you have - the desired task in your gin file, or you use a differently named macro, you - should still pass a value for this flag, since launch script will launch a - new job per value of this flag. - -+ **Not setting `replace_gin_file=True` when using a different binary from - train.py.** This will usually yield an error that there is no `train` - function. - -+ **No metrics being logged.** It can be tempting to use gin files usually - used for evaluation. However, one must ensure that the corresponding SeqIO - evaluators still log to the tensorboard, otherwise you won’t see the - metrics. - -+ **Slow `train_eval`.** While the approach outlined above separates out the - infer_eval job, it may be that even train_eval is too slow. In these - situations, we suggest adding the metrics from train_eval into the - `metrics_fn` argument of the SeqIO task and have them be computed in the - auxiliary job as well. To do this with teacher forcing, you will have to use - `train.py` instead of `eval.py`. - -+ **Using `CHECKPOINT_PATH` rather `INITIAL_CHECKPOINT_PATH`.** For legacy - reasons, the auxiliary job uses the macro `INITIAL_CHECKPOINT_PATH` rather - than `CHECKPOINT_PATH` as found in `eval.gin`. Make sure to use the latter - macro building your gin scripts. - -+ **Gin macros being ignored when passed through the format - `gin.{MACRO}={VAL}`.** In the current setup, you must include all gin macros - in the gin script. Attempting to pass them as additional flags will usually - not work. - -+ **Not setting `final_auxiliary_job_steps=0` when performing continuous - evaluation.** The current parameter controller uses this as a check. When - this is true, it will replace the `EVAL_OUTPUT_DIR` folder with the current - `MODEL_DIR`, so that the evaluation metrics are saved in the right place and - the metrics are showed correctly on the tensorboard. diff --git a/t5x-main/docs/usage/decoding.md b/t5x-main/docs/usage/decoding.md deleted file mode 100644 index 89846121e21707e78f081ba1b8aa88d173091923..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/decoding.md +++ /dev/null @@ -1,199 +0,0 @@ -# Decoding - - -This page outlines the decoding functions that T5X provides out-of-the-box and -how custom decoding functions can be used for a Transformer model, i.e., an -instance of -[`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb). -Here we refer to decoding as a process of generating a sequence of items from a -fixed alphabet (e.g., generating token ids from the vocabulary). - -There are two major ways to configure the decoding routine. The first method is -to define a decode function that follows the `DecodeFnCallable` signature. This -is more restrictive as it enforces the call signature but users don't need to -modify the model code. - -The second method is to subclass a model class and override -`predict_batch_with_aux` method. While this provides more flexibility, it -requires rewriting the method. - -## Option 1: defining a decoding function - -If a desired decoding process can follow `DecodeFnCallable`, it can be -registered as a private attribute of a -[`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb) -by passing it as a `decode_fn` argument to its constructor. - -### Decoding function call signature - -`DecodeFnCallable` has the following call signature - - -It takes in `inputs`, which is an int32 array with a shape `[batch_size, -max_decode_len]`. This is an input tokens to the decoder. For the standard -encoder-decoder models like T5, this is initialized as zeros with a desired -decoding length. The decoding function will populate the array with the sampled -token ids and return. - -For a decoder-only architectures such as a Prefix Language Model, `inputs` can -be a concatenated sequence of "inputs" and "targets" tokens ids. - -`tokens_to_logits` is a callable that takes in a batch of token ids and the -current autoregressive cache, performs the forward pass and returns the -resulting logits resulting and an updated cache. Note that for incremental -decoding, this function operates with a single token, i.e., the length dimension -is assumed to be 1. - -`DecodeFnCallable` is designed to be as general as possible. This results in -some of the arguments being somewhat generic for a specialized decoding -algorithm. For example, `num_decodes` refers to the number of decoded samples to -be returned. In the case of beam search, `num_decodes` corresponds to what is -commonly known as `beam_size`, with returned sequences sorted by the beam -scores. For temperature sampling, we perform `num_decodes` *independent* -sampling procedures with different random seeds and sort them by the log -probability of the generated sequences. - -For custom decoding functions, there might be additional arguments. To support -these, we provide `**kwargs`. - -Another usage of `**kwargs` is calling `decoding_fn` multiple times without -recompiling the model. This pattern is used in -[Prediction Service](https://github.com/google-research/t5x/blob/main/t5x/google/prediction_service/README.md). -For a compiled model, different values of `alpha` can be passed e.g., -`decoder_params = {"alpha": 0.7}` where `decoder_params` is the argument to -`predict_batch_with_aux`. It is unpacked and passed to `beam_search` function. -Note that the Prediction Service uses -[`predict_batch_with_aux`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=func:%5Cbpredict_batch_with_aux%5Cb), -which is one of the two public methods. This method is useful if auxiliary -outputs (e.g., scores of the predictions) are to be returned. The other method -is -[`predict_batch`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=func:%5Cbpredict_batch%5Cb), -which simply returns the predictions. - -### Beam search - -The following lines can be added to a gin file in order to use -[beam search](https://github.com/google-research/t5x/blob/main/t5x/decoding.py;l=881;rcl=446762159) -as a decoding function for an encoder-decoder model. - -```gin -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 -models.EncoderDecoderModel.decode_fn = @decoding.beam_search -decode.beam_search.alpha = 0.6 -``` - -Note that we skip the gin boilerplate code such as gin dynamic registration. -Please refer to [T5X Gin Primer](gin.md) for more details. - -The beam search behavior is controlled by the arguments passed to `beam_search`. -We provide details for a few of them below. - -#### `num_decodes` - -If `num_decodes` are configured with `gin.register`, it is overridden by the -value explicitly passed by the caller e.g., -`models.EncoderDecoderModel.predict_batch_with_aux`. This is because the -information about `num_decodes` is needed to prepare the encoder inputs and -outputs expanded by `num_decodes` times in the batch dimension. - -We recommend that `num_decodes` be specified *only* in -`models.EncoderDecoderModel.predict_batch_with_aux`. - -#### `alpha` - -This is the brevity penalty introduced in -[Wu et al. 2016](https://arxiv.org/abs/1609.08144) to penalize short sequences. - -#### `max_decode_len` - -For evaluation, we typically don't want to truncate the examples by a specified -sequence length. Therefore, we dynamically obtain the length information from -the batch of examples. The default behavior of `seqio.Evaluator` is to use the -maximum length of a task but, this can be overridden. - -Since the length information is provided dynamically, we don't set -`max_decode_len` in gin. Instead we pass the relevant `inputs` array to -`beam_search` whose length is the dynamically determined maximum length. - -If `max_decode_len` is explicitly specified via gin, this will override the -implicitly determined length information unless it is passed by -`predict_batch_with_aux`. - -### Temperature sampling - -[Temperature sampling](https://github.com/google-research/t5x/blob/main/t5x/decoding.py;l=37;rcl=446762159) -can be used for multiple decoding strategies. The following lines configures -temperature sampling as a `decode_fn`. - -```gin -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1 -models.EncoderDecoderModel.decode_fn = @decoding.temperature_sample -decoding.temperature_sample: - temperature = 0.5 - topk = 20 -``` - -Similar specification can be used for other model types by replacing -`models.EncoderDecoderModel` with the relevant model class, e.g. -`models.PrefixLanguageModel`. - -The sampling behavior is controlled by the arguments passed to -`temperature_sample`. We provide details for a few of them below. - -#### `temperature` - -A probabilistic model outputs a probability distribution over a pre-defined -alphabet. For example, a language model outputs *logits*, which are unnormalized -probability values for each item in the vocabulary. We use a language model as a -running example. A sampling process involves *sampling* from the predicted -distribution one item at a time conditioned on the previously generated items -until a given number of items are generated or a sentinel token that represents -the end of sequence is generated. - -Temperature modifies the unnormalized probability distribution at each step. For -each item $$i$$ in the vocabulary, its probability predicted by the model is -given by - -$$p_i \propto \exp\left(\frac{x_i}{T} \right)$$ - -where $$T$$ is the temperature and $$x_i$$ is the logits value corresponding to -item $$i$$. As $$T \to 0$$, the distribution puts all probability mass to the -item with the highest probability. In other words, the sampling process becomes -a greedy search. - -In the other extreme, as $$T \to \infty$$, the predicted distribution becomes -uniform. - -#### `topk` - -By specifying strictly positive integer value for `topk`, the sampling process -in each step is limited to the `k` items with highest probabilities. `topk` also -uses `temperature` to modify the logits corresponding to the top `k` items. - -#### `topp` - -By specifying non-zero positive float value for `topp`, the sampling process is -limited to a subset of the vocabulary $$V^{(p)} \subset V$$, which is defined by -the smallest set such that - -$$\sum_{i \in V^{(p)}} p_i \ge p$$ - -where $$p_i$$ is the conditional distribution at each time step for item $$i$$. -This is called "Nucleus sampling", which was introduced by -[Holtzman et al. ICLR 2020](https://openreview.net/forum?id=rygGQyrFvH). - -IMPORTANT: Only one of `topk` or `topp` can be used. - -## Option 2: subclassing a model class - -If `DecodeFnCallable` is not flexible enough for your custom decoding function, -you can subclass the model class and override `predict_batch_with_aux` method. -While the model class can be any instance of -[`BaseTransformerModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbBaseTransformerModel%5Cb), -we recommend that you subclass the existing models such as -[`EncoderDecoderModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py?q=symbol:%5CbEncoderDecoderModel%5Cb) -and only override `predict_batch_with_aux` method. - -`predict_batch_with_aux` method also has a required call signature, but it is -significantly more flexible. It should return a tuple of predicted sequence -array and auxiliary outputs such as score. diff --git a/t5x-main/docs/usage/eval.md b/t5x-main/docs/usage/eval.md deleted file mode 100644 index f02890f1bc856be8723a6e0573a86d5ca271bcc3..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/eval.md +++ /dev/null @@ -1,226 +0,0 @@ -# Evaluating a Model - - -## Introduction - -This page outlines the steps to evaluate a model with T5X on downstream tasks -defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md). - -Refer to this tutorial when you have an existing model that you want to -evaluate. If you would like to fine-tune your model before evaluation, please -refer to the [fine-tuning](finetune.md) tutorial. You can run evals as part of -your fine-tuning run as well. - -## Overview - -Evaluating a model with T5X consists of the following steps: - -1. Choose the model to evaluate. -1. Choose the SeqIO Task/Mixture to evaluate the model on. -1. Write a Gin file that configures the model, SeqIO Task/Mixture and other - details of your eval run. -1. Launch your experiment locally or on XManager. -1. Monitor your experiment and parse metrics. - -These steps are explained in detail in the following sections. An example run -that evaluates a fine-tuned T5-1.1-Small checkpoint on the -[(Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/) -is also showcased. - -## Step 1: Choose a model - -To evaluate a model, you need a Gin config file that defines the model params, -and the model checkpoint to load from. For this example, a T5-1.1-Small model -fine-tuned on the -[`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021) -SeqIO Task will be used: - -+ Model checkpoint - - [`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/) -+ Model Gin file - - [`t5x/configs/models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin). - -If you would like to fine-tune your model before evaluation, please follow the -[fine-tuning](finetune.md) tutorial, and continue to Step 2. A list of all -available pre-trained models (with model checkpoints and Gin config files) are -available in the [Models](https://github.com/google-research/t5x/blob/main/docs/models.md) documentation. - -## Step 2: Choose a SeqIO Task/Mixture - -A SeqIO Task encapsulates the data source, the preprocessing logic to be -performed on the data before querying the model, the postprocessing logic to be -performed on model outputs, and the metrics to be computed given the -postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks -and enables fine-tuning a model on multiple Tasks simultaneously. - -Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/), -[SuperGLUE](https://super.gluebenchmark.com/), -[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate), -[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/), -[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been -implemented as SeqIO Tasks/Mixtures and can be used directly. These -Tasks/Mixtures are defined in -[`t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py) and -[`t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py). - -For the example run, you will evaluate the model on the Natural Questions -benchmark, which has been implemented as the `natural_questions_open` Task in -[`/third_party/google_research/google_research/t5_closed_book_qa/t5_cbqa/tasks.py`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=98&rcl=370261021). -Here's an example of a single row of preprocessed data from this Task: - -```python -{ - 'inputs_pretokenized': 'nq question: what was the main motive of salt march', - 'inputs': [3, 29, 1824, 822, 10, 125, 47, 8, 711, 10280, 13, 3136, 10556, 1] - 'targets_pretokenized': 'challenge to British authority', - 'targets': [1921, 12, 2390, 5015, 1], - 'answers': ['challenge to British authority'] -} -``` - -## Step 3: Write a Gin Config - -After choosing the model and SeqIO Task/Mixture for your run, the next step is -to configure your run using Gin. If you're not familiar with Gin, reading the -[T5X Gin Primer](gin.md) is recommended. - -T5X provides a Gin file that configures the T5X eval job (located at -[`t5x/configs/runs/eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin)), -and expects a few params from you. These params can be specified in a separate -Gin file, or via commandline flags. Following are the required params: - -+ `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1). - For the example run, set this to - `'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`. -+ `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run eval - on (from Step 2). For the example run, set this to - `'natural_questions_open'`. -+ `EVAL_OUTPUT_DIR`: A path to write eval outputs to. When launching using - XManager, this path is automatically set and can be accessed from the - XManager Artifacts page. When running locally using Blaze, you can - explicitly pass a directory using a flag. Launch commands are provided in - the next step. - -In addition to the above params, you will need to import -[`eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin) and the -Gin file for the model, which for the example run is -[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin). - -```gin -include 'runs/eval.gin' -include 'models/t5_small.gin' -``` - -Note that the `include` statements use relative paths in this example. You will -pass an appropriate `gin_search_paths` flag to locate these files when launching -your run. Absolute paths to Gin files can also be used, e.g. - -```gin -include 't5x/configs/runs/eval.gin' -include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin' -``` - -You will also need to import the Python module(s) that register SeqIO Tasks and -Mixtures used in your run. For the example run, we add `import -google_research.t5_closed_book_qa.t5_cbqa.tasks` -since it is where 'glue_v002_proportional' is registered. - -If you choose a module that is not included as a dependency in the T5X trainer -[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=76;rcl=398627055), or if you -have defined your gin config file in a location other than the -[T5X config directory](https://github.com/google-research/t5x/blob/main/t5x/configs/), you will -need to follow the instructions in the -[Advanced Topics section](#custom-t5x-binaries) to link in the custom gin file -and/or task definition. - -Note that for most common Task/Mixtures, such as the `glue_v002_proportional` -used in this tutorial, the necessary modules are already included. It is also -possible to skip writing a Gin file and instead pass the params as flags when -launching the eval job (see instructions in Step 4). - -Finally, your Gin file should look like this: - -```gin -include 't5x/configs/runs/eval.gin' -include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin' - -# Register necessary SeqIO Tasks/Mixtures. -import google_research.t5_closed_book_qa.t5_cbqa.tasks - -CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000' -MIXTURE_OR_TASK_NAME = 'natural_questions_open' -``` - -See -[`t5_1_1_small_cbqa_natural_questions.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/eval/t5_1_1_small_cbqa_natural_questions.gin) -for this example. - -In this example, we run the evaluation on one checkpoint. It is common to -evaluate with multiple checkpoints. We provide an easy way to do so *without* -having to recompile the model graph for each checkpoints. This is simply done by -adding `utils.RestoreCheckpointConfig.mode = "all"` to a gin file. Our -`t5x/configs/runs/eval.gin` uses "specific" mode. - -## Step 4: Launch your experiment - -To launch your experiment locally (for debugging only; larger checkpoints may -cause issues), run the following on commandline: - -```sh -EVAL_OUTPUT_DIR="/tmp/model-eval/" -python -m t5x.eval_unfragmented \ - --gin_file=t5x/google/examples/flaxformer_t5/configs/examples/eval/t5_1_1_small_cbqa_natural_questions.gin \ - --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \ - --alsologtostderr -``` - -Note that relative paths can be used to locate the gin files. For that, multiple -comma-separated paths can be passed to the `gin_search_paths` flag, and these -paths should contain all Gin files used or included in your experiment. - - -You can have a look inside -[`eval.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/eval.gin) to see -other useful parameters that it is possible to pass in, including dataset split, -batch size, and random seed. - -## Step 5: Monitor your experiment and parse metrics - - -After evaluation has completed, you can parse metrics into CSV format using the -following script: - -```sh -EVAL_OUTPUT_DIR= # from Step 4 if running locally, from XManager Artifacts otherwise -VAL_DIR="$EVAL_OUTPUT_DIR/inference_eval" -python -m t5.scripts.parse_tb \ - --summary_dir="$VAL_DIR" \ - --seqio_summaries \ - --out_file="$VAL_DIR/results.csv" \ - --alsologtostderr -``` - -## Next Steps - -Now that you have successfully evaluated a model on the Natural Questions -benchmark, here are some topics you might want to explore next: - -+ [Running inference on a model.](infer.md) -+ [Fine-tuning a model.](finetune.md) -+ [Training a model from scratch.](pretrain.md) - -We also touch upon a few advanced topics related to evaluations below that might -be useful, especially when customizing your eval job. - -## Advanced Topics - - -### Defining a custom SeqIO Task/Mixture to evaluate on {.no-toc} - -Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md). - -### Defining a custom metric to evaluate - -The best way to define a custom metric is to define a new SeqIO Task/Mixture -that contains this custom metric. Please refer to the SeqIO Documentation on -[custom metrics](https://github.com/google/seqio/blob/main/README.md#metrics). diff --git a/t5x-main/docs/usage/finetune.md b/t5x-main/docs/usage/finetune.md deleted file mode 100644 index 5a27f16f30844d4b567cefe97da0adde682a59ad..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/finetune.md +++ /dev/null @@ -1,286 +0,0 @@ -# Fine Tuning a Model - - -## Introduction - -This page outlines the steps to fine-tune an existing pre-trained model with T5X -on common downstream tasks defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md). This is one of -the simplest and most common use cases of T5X. If you're new to T5X, this -tutorial is the recommended starting point. - -## Overview - -Fine-tuning a model with T5X consists of the following steps: - -1. Choose the pre-trained model to fine-tune. -2. Choose the SeqIO Task/Mixture to fine-tune the model on. -3. Write a Gin file that configures the pre-trained model, SeqIO Task/Mixture - and other details of your fine-tuning run. -4. Launch your experiment locally or on XManager. -5. Monitor your experiment and parse metrics. - -These steps are explained in detail in the following sections. An example run -that fine-tunes a T5-small checkpoint on WMT14 English to German translation -benchmark is also showcased. - -## Step 1: Choose a pre-trained model - -To use a pre-trained model, you need a Gin config file that defines the model -params, and the model checkpoint to load from. For your convenience, TensorFlow -checkpoints and Gin configs for common T5 pre-trained models have been made -available for use in T5X. A list of all the available pre-trained models (with -model checkpoints and Gin config files) are available in the -[Models](https://github.com/google-research/t5x/blob/main/docs/models.md) documentation. - -For the example run, you will use the T5 1.1 Small model. The Gin file for this -model is located at -[`/t5x/examples/t5/t5_1_1/small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin), -and the checkpoint is located at -[`gs://t5-data/pretrained_models/t5x/t5_1_1_small`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/t5_1_1_small). - -## Step 2: Choose a SeqIO Task/Mixture - -A SeqIO Task encapsulates the data source, the preprocessing logic to be -performed on the data before querying the model, the postprocessing logic to be -performed on model outputs, and the metrics to be computed given the -postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks -and enables fine-tuning a model on multiple Tasks simultaneously. - -### Standard Tasks - -Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/), -[SuperGLUE](https://super.gluebenchmark.com/), -[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate), -[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/), -[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been -implemented as SeqIO Tasks/Mixtures and can be used directly. These -Tasks/Mixtures are defined in -[`third_party/py/t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py) -and -[`third_party/py/t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py). - -For the example run, you will fine-tune the model on the WMT14 English to German -translation benchmark, which has been implemented as the -[`wmt_t2t_ende_v003`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py;l=209;rcl=417815592) -Task. - -### Custom Tasks - -It is also possible to define your own custom task. See the -[SeqIO documentation](https://github.com/google/seqio/blob/main/README.md) for how to do this. As a note, Tasks -defined using the -[old T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/dataset_providers.py) -may also be used by T5X. If using a custom Task, you will need to follow the -instructions in the [Advanced Topics section](#custom-t5x-binaries) at the end -of this tutorial to make sure the module containing your task is included. - -When defining a custom task, you have the option to cache it on disk before -fine-tuning. The instructions for this are -[here](https://github.com/google/seqio/blob/main/README.md#optional-offline-caching). Caching may improve -performance for tasks with expensive pre-processing. By default, T5X expects -tasks to be cached. To finetune on a task that has not been cached, set -`--gin.USE_CACHED_TASKS=False`. - -## Step 3: Write a Gin Config - -After choosing the pre-trained model and SeqIO Task/Mixture for your run, the -next step is to configure your run using Gin. If you're not familiar with Gin, -reading the [T5X Gin Primer](gin.md) is recommended. - -T5X provides a Gin file that configures the T5X trainer for fine-tuning (located -at -[`t5x/configs/runs/finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin)), -and expects a few params from you. These params can be specified in a separate -Gin file, or via commandline flags. Following are the required params: - -+ `INITIAL_CHECKPOINT_PATH`: This is the path to the pre-trained checkpoint - (from Step 1). For the example run, set this to - `'gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000'`. -+ `TRAIN_STEPS`: Number of fine-tuning steps. This includes the number of - steps that the model was pre-trained for, so make sure to add the step - number from the `INITIAL_CHECKPOINT_PATH`. For the example run, to fine-tune - for `20_000` steps, set this to `1_020_000`, since the initial checkpoint is - the `1_000_000`th step. -+ `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run (from - Step 2). For the example run, set this to `'wmt_t2t_ende_v003'`. -+ `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int - length for that feature. After preprocessing, features are truncated to the - provided value. For the example run, set this to `{'inputs': 256, 'targets': - 256}`. -+ `MODEL_DIR`: A path to write fine-tuned checkpoints to. When launching using - XManager, this path is automatically set and can be accessed from the - XManager Artifacts page. When running locally using Blaze, you can - explicitly pass a directory using a flag. Launch commands are provided in - the next step. -+ `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained - using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should - be set to `pretraining batch_size` * `pretrained target_token_length`. For - T5 and T5.1.1: `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. - -In addition to the above params, you will need to include -[`finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin) -and the Gin file for the pre-trained model, which for the example run is -[`t5_1_1/small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin). - -```gin -include 't5x/configs/runs/finetune.gin' -include 't5x/examples/t5/t5_1_1/small.gin' -``` - -You will also need to import the Python module(s) that register SeqIO Tasks and -Mixtures used in your run. For the example run, we add `import t5.data.tasks` -since it is where `wmt_t2t_ende_v003` is registered. - - -Finally, your Gin file should look like this: - -```gin -include 't5x/configs/runs/finetune.gin' -include 't5x/examples/t5/t5_1_1/small.gin' - -# Register necessary SeqIO Tasks/Mixtures. -import t5.data.tasks - -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} -TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. -DROPOUT_RATE = 0.0 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" -LOSS_NORMALIZING_FACTOR = 233472 -``` - -See -[`t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin) -for this example. - - -## Step 4: Launch your experiment - -To launch your experiment locally (for debugging only; larger checkpoints may -cause issues), run the following on commandline: - -```sh -MODEL_DIR="/tmp/finetune-model/" -python -m t5x.train_unfragmented \ - --gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \ - --gin.MODEL_DIR=\"${MODEL_DIR}\" \ - --alsologtostderr -``` - -Note that multiple comma-separated paths can be passed to the `gin_search_paths` -flag, and these paths should contain all Gin files used or included in your -experiment. - - -After fine-tuning has completed, you can parse metrics into CSV format using the -following script: - -```sh -MODEL_DIR= # from Step 4 if running locally, from XManager Artifacts otherwise -VAL_DIR="$MODEL_DIR/inference_eval" -python -m t5.scripts.parse_tb \ - --summary_dir="$VAL_DIR" \ - --seqio_summaries \ - --out_file="$VAL_DIR/results.csv" \ - --alsologtostderr -``` - -### Metric Explanations - -By default, t5x logs many metrics to TensorBoard, many of these seem similar but -have important distinctions. - -The first two graphs you will see are the `accuracy` and `cross_ent_loss` -graphs. These are the *token-level teacher-forced* accuracy and cross entropy -loss respectively. Each of these graphs can have multiple curves on them. The -first curve is the `train` curve. This is calculated as a running sum than is -then normalized over the whole training set. The second class of curves have the -form `training_eval/${task_name}`. These curves are created by running a subset -(controlled by the `eval_steps` parameter of the main train function) of the -validation split of `${task_name}` through the model and calculating these -metrics using teacher-forcing. These graphs can commonly be used to find -"failure to learn" cases and as a warning sign of overfitting, but these are -often not the final metrics one would report on. - -The second set of graphs are the ones under the collapsible `eval` section in -TensorBoard. These graphs are created based on the `metric_fns` defined in the -SeqIO task. The curves on these graphs have the form -`inference_eval/${task_name}`. Values are calculated by running the whole -validation split through the model in inference mode, commonly auto-regressive -decoding or output scoring. Most likely these are the metrics that will be -reported. - -More information about the configuration of the datasets used for these -different metrics can be found [here](#train-train-eval-and-infer-eval). - -In summary, the metric you actually care about most likely lives under the -`eval` tab rather, than in the `accuracy` graph. - -## Next Steps - -Now that you have successfully fine-tuned a pre-trained model on WMT, here are -some topics you might want to explore next: - -+ [Evaluating a fine-tuned model.](eval.md) -+ [Running inference on a fine-tuned model.](infer.md) -+ [Training a model from scratch.](pretrain.md) - -We also touch upon a few advanced topics related to fine-tuning below that might -be useful, especially when customizing your fine-tuning job. - -## Advanced Topics - -### `train`, `train_eval` and `infer_eval` {.no-toc} - -A -[`DatasetConfig`](https://github.com/google-research/t5x/blob/main/t5x/utils.py?l=113&rcl=375475889) -object is used to configure loading SeqIO Tasks/Mixtures for training and eval. -If you take a closer look at -[`runs/finetune.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin), -you will see that there are three `DatasetConfig` objects defined and passed to -the train function: `train_dataset_cfg`, `train_eval_dataset_cfg`, -`infer_eval_dataset_cfg`. Here's a brief description of these configs: - -+ `train`: This configures the Task/Mixture that the model will be fine-tuned - on. -+ `train_eval`: This configures the Task/Mixture that is used to compute - training metrics on the eval split, e.g. perplexity. These metrics are - defined in the - [`Model`](https://github.com/google-research/t5x/blob/main/t5x/models.py;l=257-267;rcl=394045248) - class and the eval fn is located - [here](https://github.com/google-research/t5x/blob/main/t5x/trainer.py;l=257;rcl=398487394). -+ `infer_eval`: This configures the Task/Mixture that is used to compute - metrics on inferred model outputs (e.g., comparing decoded model outputs and - targets). These metrics are defined in the SeqIO Task/Mixture and the eval - fn is located - [here](https://github.com/google/seqio/tree/main/seqio/evaluation.py?l=423&rcl=373643592) - -### Using separate SeqIO Tasks/Mixtures for fine-tuning and eval {.no-toc} - -Commonly, the same SeqIO Task/Mixture is used for training and eval. It is set -by the `MIXTURE_OR_TASK_NAME` macro in your fine-tune Gin file from Step 3 -above, and is passed to `train_dataset_cfg`, `train_eval_dataset_cfg`, -`infer_eval_dataset_cfg`. The `train` split is used for training and the -`validation` split is used for evals. However, you can override these params in -your fine-tune Gin config. For example, if you want to fine-tune on all GLUE -tasks but evaluate only on GLUE STS benchmark, you can override the SeqIO -Task/Mixture used for `infer_eval` in your fine-tune Gin file as follows: - -```gin -include 'runs/finetune.gin' -include 'models/t5_small.gin' - -MIXTURE_OR_TASK_NAME = 'glue_v002_proportional' -MIXTURE_OR_TASK_MODULE = 't5.data.tasks' -TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 84} -TRAIN_STEPS = 1_500_000 # includes 1_000_000 pretrain steps -INITIAL_CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000' -infer_eval/utils.DatasetConfig.mixture_or_task_name = 'glue_stsb_v002' -``` - -Other params in `finetune.gin` can be overridden in the same way. - - -### Defining a custom SeqIO Task/Mixture to fine-tune on {.no-toc} - -Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md). diff --git a/t5x-main/docs/usage/gin.md b/t5x-main/docs/usage/gin.md deleted file mode 100644 index 3525b734da747306e27a539dcc6095e35faf6b42..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/gin.md +++ /dev/null @@ -1,395 +0,0 @@ -# Gin Primer - - -[Gin](https://github.com/google/gin-config/blob/main/README.md) is a lightweight configuration framework for Python, -based on dependency injection. While T5X does not employ gin in its core -libraries, it is used to configure runs of the `train`, `eval`, and `infer` -scripts. This usage is a bit different (and more limited) than how gin is -typically applied, so this primer should be useful even for those who may be -familiar with gin from other libaries (e.g., T5 or Mesh TensorFlow). - -Nevertheless, you may still find it helpful to refer to the -[gin documentation](https://github.com/google/gin-config/blob/main/README.md) for more background. - -[TOC] - -## Gin in T5X Scripts - -Rather than plumbing run arguments and hyperparameters through via limited set -of command-line flags or a flat configuration schema, T5X's gin integration -allows you to parameterize the top-level run functions (`train`, `evaluate`, and -`infer`) as well as any object or function that is passed to them. This enables -a vast amount of flexibility over your runs without needing to modify any code -within the core T5X library. - -For example, you can implement a Python class in your own codebase (e.g., a -custom model or trainer) and use gin to pass an instance of it to the T5X XM -launcher without having to fork any code. Previously you needed to implement -every experimental idea in the core library (no matter how widely used it would -be) and add a ConfigDict flag to enable/disable it, resulting in significant -code debt over time. - -On the other hand, gin can sometimes be too powerful, allowing users the ability -to bind arguments throughout a codebase, which makes it difficult or impossible -to update "private" internal interfaces. However, by limiting configurability to -a single top-level function and its arguments we can better control the -configurable surface to public interfaces and user-owned code, and also avoid -unintended side effects. - -### An Example - -Let's look at the `evaluate` call signature from -[eval.py](https://github.com/google-research/t5x/blob/main/t5x/eval.py) as an example: - -```py -def evaluate(*, - model: models.BaseModel, - dataset_cfg: utils.DatasetConfig, - restore_checkpoint_cfg: utils.RestoreCheckpointConfig, - partitioner: partitioning.BasePartitioner, - output_dir: str): - """Evaluation function. - - Args: - model: The model object to use for inference. - dataset_cfg: Specification for the dataset to infer based on. - restore_checkpoint_cfg: Specification for the model parameter checkpoint to - load. - partitioner: The partitioner for the model parameters and - data across devices. - output_dir: Path to directory to write temporary files and final results. - """ - ... -``` - -In the binary, the user-provided gin configuration file will be parsed. It -specifies which values should be bound to the `evaluate` argument, after which -we can directly call the fully-bound function without any arguments. Basically, -we are creating a custom closure of `evaluate` (a la `functools.partial`) but -specifying the arguments via gin instead of Python. - -Furthermore, this ability to bind custom arguments is recursive. Not only can we -bind the arguments of `evaluate`, but we can also bind the constructor and -method arguments of the instance of `models.BaseModel` that we pass to -`evaluate`. - -Let's now look at an example of a gin configuration for parameterizing -`evaluate`, specifically evaluating a -[T5 model fine-tuned for closed book question answering](http://goo.gle/t5-cbqa) -on [Natural Questions Open](https://ai.google.com/research/NaturalQuestions): - -```py -from __gin__ import dynamic_registration - -import __main__ as eval_script -from t5x import models -from t5x import partitioning -from t5x import utils - -MODEL = %gin.REQUIRED - -eval_script.evaluate: - model = %MODEL - output_dir = '/tmp/t5x_eval' - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - -# Load model with overrides. -include 'models/t5_large.gin' -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1 - -utils.DatasetConfig: - mixture_or_task_name = 'natural_questions_open' - split = 'test' - task_feature_lengths = None - batch_size = 32 - shuffle = False - seed = 0 - use_cached = False - pack = False - use_custom_packing_ops = False - module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks' - -partitioning.PjitPartitioner: - num_partitions = 1 - -utils.RestoreCheckpointConfig: - mode = 'specific' - path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo' - assignment_map = None - strict = True - dtype = None -``` - -Let's go through this block-by-block. - -```py -from __gin__ import dynamic_registration -``` - -The first line imports a new gin feature (see cl/372624800 for more details) to -allow us to register functions and objects for configuration from within the gin -file itself without having to modify or decorate functions from the imported -packages. - -```py -import __main__ as eval_script -from t5x import models -from t5x import utils -``` - -The second block imports the modules containing the components we plan to -configure in this file and is required for dynamic registration. Note that only -those functions and objects that we specify below will actually be configured, -not everything in the module. Also, as is the case in Python, the binary module -is referred as `__main__`, although we rename it to `eval_script` for clarity in -the rest of the config. - -```py -MODEL = %gin.REQUIRED -``` - -The third block creates a -[gin macro](https://github.com/google/gin-config/tree/master/docs/index.md#gin-macros) -(essentially a lazy reference) and for now sets it to refer to the special macro -`gin.REQUIRED`, which will cause a failure during parsing of the configuration -if not updated via a later assignment in the config file or command-line flags -(see [below](#command-line-usage)). - -```py -eval_script.evaluate: - model = %MODEL - output_dir = '/tmp/t5x_eval' - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() -``` - -The fourth block specifies the binding for the `evaluate` function. For `model`, -we pass the value of the `MODEL` macro (to be defined later). For `output_dir` -we pass a string path. For `dataset_cfg`, `restore_checkpoint_cfg`, and -`partitioner`, we pass instantiations of `DatasetConfig`, -`RestoreCheckpointConfig`, and `PjitPartitioner`, which are defined in -[utils.py](https://github.com/google-research/t5x/blob/main/t5x/utils.py) and -[partitioning.py](https://github.com/google-research/t5x/blob/main/t5x/partitioning.py) -respectively. The '@' prefix tells gin that the following is a configured -function or class, and the '()' suffix signifies that it should be called (in -the cases of class, this means calling the constructor). If we wanted to pass in -the closure (or a partially bound) function instead of its return value, we -would leave off the parentheses. - -The remainder of the file deals with defining the `MODEL` macro and fully -binding these constructors. - -```py -# Load model with overrides. -include 't5x/examples/t5/t5_1_1/large.gin' -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1 -``` - -Although we could define `MODEL = model.EncoderDecoderModel()` here, we prefer -to create a separate gin file that defines it. This makes it easier to reuse -parts of the common configurations. All of the bindings in the newly included -file are read and override any conflicting ones defined so far in this file. -It's equivalent to copy and pasting the contents of the included file at this -location in the config. If you want to see how the model itself is instantiated, -you can refer to -[t5_1_1/large.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/large.gin) -(which simply overrides a few values from -[t5_1_1/base.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/base.gin)). - -The final line of this block shows an example of how you can modify the default -arguments of the `EncoderDecoderModel` instance referenced by `%MODEL`, in this -case changing the default beam size it will use during prediction. Notice that -since we are only binding one argument here, we choose to write it on a single -line instead of using the block binding syntax used elsewhere in the file. - -```py -utils.DatasetConfig: - mixture_or_task_name = 'natural_questions_open' - split = 'test' - task_feature_lengths = None - batch_size = 32 - shuffle = False - seed = 0 - use_cached = False - pack = False - use_custom_packing_ops = False - module = 'google_research.t5_closed_book_qa.t5_cbqa.tasks' - -partitioning.PjitPartitioner: - num_partitions = 1 - -utils.RestoreCheckpointConfig: - mode = 'specific' - path = 'gs://t5-data/pretrained_models/cbqa/large_ssm_nqo' - assignment_map = None - strict = True - dtype = None -``` - -The last 3 blocks are fairly straightforward. They are effectively setting the -attributes of these dataclasses by binding values to their constructors that -will be used when they are instantiated and passed to `evaluate`, as specified -in the fourth block. - -### Scoping - -The above example lacks one key component of gin: -[scopes](https://github.com/google/gin-config/blob/main/README.md#4-configuring-the-same-function-in-different-ways-scopes). - -What happens if you need to use a class or function multiple times but with -different bound values? - -A clear example of this is in the top-level `train` function (in -[train.py](https://github.com/google-research/t5x/blob/main/t5x/train.py)). The call signature -includes 3 different instances of `utils.DatasetConfig`: one for the train -dataset, one for the "train-eval" dataset (used for evaluation with teacher -forcing), and one for the "infer-eval" dataset (used for evaluation with -inference/decoding). - -The solution is to prefix each instance with a unique identifier both when -specifying where it is to be passed to `train` and when binding its arguments. -For example, the gin file might look like the following (skipping the irrelevant -bits): - -```py -... - -train_script.train: - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - ... - -train/utils.DatasetConfig: - mixture_or_task_name = 'train_mixture' - split = 'train' - ... - -train_eval/utils.DatasetConfig: - mixture_or_task_name = 'eval_mixture' - split = 'validation' - ... - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = 'eval_mixture' - split = 'test' - ... -``` - -We have therefore configured 3 different scoped-versions of -`utils.DatasetConfig` producing 3 separate instances that are passed to `train`. - -Note that these three scopes will all inherit from the base scope, so if you -want to set a shared binding, you may directly configure `utils.DatasetConfig` -without a scope prefix. - -## Command-Line Usage - -So now that you have a gin config, how do you pass it to the script? There are -two ways: gin files and override flags. - -1. **Gin Files** You have already seen an example of a gin file above. You can - specify the gin file(s) to use in your script via the `--gin_file` flag. If - you want to load multiple gin files, you can set the flag multiple times and - the files will be loaded in order, with the second potentially overriding - the first when there are conflicts. It is possible to supply a - comma-separate list of search prefixes via `--gin_search_paths` and then - only specify the relative path to the `--gin_file` flags. However, we - strongly recommend against using `--gin_search_paths`. Using absolute paths - via the `--gin_file` flags will reduce sources of ambiguity and improve the - consistency of your scripts. - -1. **Override Flags** Gin flags allow for more fine-grained overrides of any - configurable aspect of your run. These flags follow the single-line binding - format from the above example with the addition of a `--gin.` prefix. For - example, if you want to override the dataset shuffling, you can set - `--gin.utils.DatasetConfig.shuffle=False`. In the train setting where there - are multiple datasets, you must supply the appropriate scope, e.g., - `--gin.train/utils.DatasetConfig.shuffle=False`. These bindings are - processed in order *after* the gin files are loaded, and therefore overwrite - any previously assigned value in the gin files. - -**Note:** when supplying a string, dict, list, or tuple value via a flag, you -must put it in quotes. In the case of strings, it requires escaped quotes -(`\"\"`). For example: `--gin.utils.DatasetConfig.split=\"validation\"`, -`--gin.utils.DatasetConfig.task_feature_lengths="{'inputs': 512, 'targets': -84}"`, and `--gin.dense.MlpBlock.activations="('dense', 'gelu')"` - -### An Example - -An example where you may need multiple files is with the `train` script. - -You can first specify which model you want to train by supplying a gin file -containing its definition, for example: -[t5_1_1/small.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin). - -You may then specify a run config that supplies some of the common defaults. For -example, if you are doing pretraining you can use -[runs/pretrain.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/pretrain.gin), -and if you are doing finetuning, you can use -[runs/finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin). - -We can apply these two files with the following command: - -```sh -python -m t5x.train_unfragmented \ - --gin_file=t5x/examples/t5/t5_1_1/small.gin \ - --gin_file=t5x/configs/runs/finetune.gin \ - --logtostderr -``` - -However, running this command will give you an error like the following: - -```sh -ValueError: MODEL_DIR/macro.value set to `%gin.REQUIRED` but not subsequently overridden. -``` - -This is because the config still includes some `gin.REQUIRED` macros that you'll -need to override with the details of your run. At the top of -[runs/finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/finetune.gin) -you'll see the list of required overrides, which we will populate for finetuning -on WMT in the updated launch command here: - -```sh -python -m t5x.train_unfragmented \ - --gin_file=t5x/examples/t5/t5_1_1/small.gin \ - --gin_file=t5x/configs/runs/finetune.gin \ - --gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \ - --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \ - --gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}" \ - --gin.TRAIN_STEPS=1_020_000 \ - --gin.MODEL_DIR=\"/tmp/t5_1_1_base_finetune_gin\" \ - --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000\" \ - --logtostderr -``` - -Note you may still override any registered bindings. For example, to disable -inference evaluation you may add `--gin.train.infer_eval_dataset_cfg=None`. - -### A File-only Example - -At the beginning of the primer, we saw a fully-specified run config. We can do -something similar with the previous example to create a self-contained run -configuration. -[t5_1_1/examples/small_wmt_finetune.gin](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin) -is just such an example that allows you to exactly duplicate the previous launch -command simply by calling: - -```sh -python -m t5x.train_unfragmented \ - --gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \ - --gin.MODEL_DIR=\"/tmp/t5_1_1_small_finetune_gin\" \ - --logtostderr -``` - -## Logging - -After your gin files and flag overrides are parsed, the complete configuration -will be logged to INFO, written to `config.gin` in the output directory, and -added to a TensorBoard summary. - -It is highly recommended that you review this generated config to ensure that -your overrides are working as expected. diff --git a/t5x-main/docs/usage/gpu-usage.md b/t5x-main/docs/usage/gpu-usage.md deleted file mode 100644 index dedcd88ab1da45a763948519dcb10a24f2379c38..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/gpu-usage.md +++ /dev/null @@ -1,87 +0,0 @@ -# GPU Scripts - -# Warning! -An updated version of T5x with optimized GPU performance (18-80% perf gains!) and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x). ------ -**NVIDIA no longer recommends using this repository and won't be updating it further.** ------ - -The [t5x/contrib/gpu](../../t5x/contrib/gpu) directory contains scripts optimized for GPU usage. - -Install with `pip install -r pile_requirements.txt` to get all pile dependencies. - -## Building the container -The Dockerfile in `t5x/contrib/gpu` given will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` - -## Running interactively -Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example: - -`t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir` - -## Downloading The Pile -Run `download_the_pile.py` to download the pile. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. - -## Single Node runs -Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build a T5X model with the Adam optimizer and relevant parameters. These will allow multi-gpu on one host. - -## Multi Node runs -For a SLURM+pyxis cluster, `example*.sub` files provide example slurm submit files (edit with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (dropping some throughput) - -## Convergence -For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes. - -| size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | -| ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- | -| small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | -| large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) | -| xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) | -| xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)| - -Note: Convergence (as shown in log) was not necessarily done with the hardware topology listed, but the listed topology is tested. Estimated Walltime is calculated assuming full throughput (seq/sec) continuously. In practice, there are compilation overheads at the beginning of each run/restart(in cluster settings) + checkpointing overheads (if any). - -(More perf improvements coming soon!) - -Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory. - -## Pretraining run commands - -### Singlenode -small: - -`t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}` - -Finetuning: -MNLI v2: -`t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}` - - -### Multinode -Arguments are as such: - -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` - -small: - -`sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1` - -large: - -`sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1` - -xl: - -`sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1` - -Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from. - -MNLI v2: - -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` - -SQuAD v1.1 - -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` - -On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision. - -WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. \ No newline at end of file diff --git a/t5x-main/docs/usage/index.rst b/t5x-main/docs/usage/index.rst deleted file mode 100644 index 8debe023e857de352f220a7313d958a347de6e7a..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/index.rst +++ /dev/null @@ -1,16 +0,0 @@ -T5X Usage Guides -================ - -.. toctree:: - :maxdepth: 2 - - pretrain.md - finetune.md - eval.md - infer.md - auxiliary.md - decoding.md - metrics.md - partitioning.md - gin.md - diff --git a/t5x-main/docs/usage/infer-files.md b/t5x-main/docs/usage/infer-files.md deleted file mode 100644 index 95d88c25cd1b736de711a4007e06ac939b80284d..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/infer-files.md +++ /dev/null @@ -1,217 +0,0 @@ -# Running inference on a Model - - -## Introduction - -This page outlines the steps to run inference a model with T5X on files -containing -[TensorFlow Examples](https://www.tensorflow.org/api_docs/python/tf/train/Example). - -## Overview - -Running inference on a model with T5X using TF Example files consists of the -following steps: - -1. Choose the model to run inference on. -1. Choose the TF Example files to run inference on. -1. Write a Gin file that configures the model, file source and other details of - your inference run. -1. Launch your experiment locally or on XManager. -1. Monitor your experiment and access predictions. - -These steps are explained in detail in the following sections. An example run -that runs inference on a fine-tuned T5-1.1-Small checkpoint on `tfrecord` files -containing the -[(Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/) -is also showcased. - -## Step 1: Choose a model - -To run inference on a model, you need a Gin config file that defines the model -params, and the model checkpoint to load from. For this example, a T5-1.1-Small -model fine-tuned on the -[`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021) -SeqIO Task will be used: - -+ Model checkpoint - - [`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/) -+ Model Gin file - - [`models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin). - -If you would like to fine-tune your model before inference, please follow the -[fine-tuning](finetune.md) tutorial, and continue to Step 2. - -## Step 2: Choose a TF Example file source - -T5X supports running inference on `tfrecord`, `recordio` and `sstable` files -containing TF Examples. For the example run, you will run inference on -`tfrecord` files containing the `'natural_questions_open'` dataset located here: -`/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*`. -Here's an example of a single row of data from this file (you can explore this -file further using [GQUI](http://shortn/_oNuDhg7jwN)): - -```json -{ # (tensorflow.Example) size=101B - features: { # (tensorflow.Features) size=99B - feature: { # (tensorflow.Features.FeatureEntry) size=27B - key: "answer" # size=6 - value: { # (tensorflow.Feature) size=17B - bytes_list: { # (tensorflow.BytesList) size=15B - value: [ "Jason Flemyng" ] # size=13 - } # features.feature[0].value.bytes_list - } # features.feature[0].value - } # features.feature[0] - feature: { # (tensorflow.Features.FeatureEntry) size=68B - key: "question" # size=8 - value: { # (tensorflow.Feature) size=56B - bytes_list: { # (tensorflow.BytesList) size=54B - value: [ "who played hyde in league of extraordinary gentlemen" ] # size=52 - } # features.feature[1].value.bytes_list - } # features.feature[1].value - } # features.feature[1] - } # features -} -``` - -## Step 3: Write a Gin Config - -After choosing the model and file source for your run, the next step is to -configure your run using Gin. If you're not familiar with Gin, reading the -[T5X Gin Primer](gin.md) is recommended. T5X provides a Gin file that configures -the T5X inference job (located at -[`t5x/configs/runs/infer_from_tfexample_file.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer_from_tfexample_file.gin)) -to run inference on TF Example files, and expects a few params from you. These -params can be specified in a separate Gin file, or via commandline flags. -Following are the required params: - -+ `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1). - For the example run, set this to - `'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`. -+ `TF_EXAMPLE_FILE_PATHS`: This is a list of paths or glob patterns to read TF - Examples from. For the example run, set this to - `['/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*']`. -+ `TF_EXAMPLE_FILE_TYPE`: This is the TF Example file format. Currently - supported file formats are `tfrecord`, `recordio` and `sstable`. For the - example run, set this to `'tfrecord'`. -+ `FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int length - for that feature. the TF Example features are truncated to the provided - value. For the example run, set this to `{'inputs': 38, 'targets': 18}`, - which is the maximum token length for the test set. -+ `INFER_OUTPUT_DIR`: A path to write inference outputs to. When launching - using XManager, this path is automatically set and can be accessed from the - XManager Artifacts page. When running locally using Blaze, you can - explicitly pass a directory using a flag. Launch commands are provided in - the next step. - -In addition to the above params, you may also need to override the -`create_task_from_tfexample_file.inputs_key` param based on the data format (it -is set to `'inputs'` by default. For the example run, the `'question'` key -contains the input (see Step 2), so add the following to your Gin config: - -```gin -create_task_from_tfexample_file.inputs_key = 'question' -``` - -Additionally, you will need to import the -[`infer_from_tfexample_file.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer_from_tfexample_file.gin) -and the Gin file for the model, which for the example run is -[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin). - -```gin -include 'runs/infer_from_tfexample_file.gin' -include 'models/t5_1_1_small.gin' -``` - -Note that the `include` statements use relative paths in this example. You will -pass an appropriate `gin_search_paths` flag to locate these files when launching -your run. Absolute paths to Gin files can also be used, e.g. - -```gin -include 't5x/configs/runs/infer_from_tfexample_file.gin' -include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin' -``` - -Finally, your Gin file should look like this: - -```gin -include 'runs/infer_from_tfexample_file.gin' -include 'models/t5_1_1_small.gin' - -CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000' -TF_EXAMPLE_FILE_PATHS = ['/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*'] -TF_EXAMPLE_FILE_TYPE = 'tfrecord' -FEATURE_LENGTHS = {'inputs': 38, 'targets': 18} -create_task_from_tfexample_file.inputs_key = 'question' -``` - -See -[`t5x/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin) -for this example. Make sure that your Gin file is linked as a data dependency to -the T5X inference -[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). If your -Gin file is not included, see the -[Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial for -instructions to add it, or skip writing a Gin file and pass the above params as -flags when launching the inference job (see instructions in Step 4). - -## Step 4: Launch your experiment - -To launch your experiment locally (for debugging only; larger checkpoints may -cause issues), run the following on commandline: - -```sh -INFER_OUTPUT_DIR="/tmp/model-infer/" -python -m t5x.infer_unfragmented \ - --gin_file=t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions_tfexample.gin \ - --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \ - --alsologtostderr -``` - -Note that multiple comma-separated paths can be passed to the `gin_search_paths` -flag, and these paths should contain all Gin files used or included in your -experiment. - - -After inference has completed, you can view predictions in the `jsonl` files in -the output dir. JSON data is written in chunks and combined at the end of the -inference run. Refer to [Sharding](#sharding) and -[Checkpointing](#checkpointing) sections for more details. - -## Next Steps - -Now that you have successfully run inference on a model, here are some topics -you might want to explore next: - -+ [Fine-tuning a model.](finetune.md) -+ [Evaluating a model.](eval.md) -+ [Training a model from scratch.](pretrain.md) - -We also touch upon a few advanced topics related to inference below that might -be useful, especially when customizing your inference job. - -## Advanced Topics - -### Dataset Sharding {#sharding .no-toc} - -You can run inference in parallel across multiple TPU slices by setting the -`num_shards` flag when running using XManager. When `num_shards > 1`, the -dataset is interleaved among the shards and the predictions are combined in the -end; hence the order of examples in the data source and the predictions in the -output json files will not match (order is guaranteed to match for `num_shards = -1` or the number of input file shards). - -### Dataset Checkpointing {#checkpointing .no-toc} - -You can control dataset checkpointing frequency by overriding the -`infer.checkpoint_period` in -[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin), -which is set to `100` by default. This means that the dataset is checkpointed -after running inferences on `checkpoint_period` batches (batches, not examples; -you can control batch size by overriding `utils.DatasetConfig.batch_size` in -[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin), it -is set to `32` by default). - - -### Defining a custom SeqIO Task/Mixture to run inference on {.no-toc} - -Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md). diff --git a/t5x-main/docs/usage/infer-seqio.md b/t5x-main/docs/usage/infer-seqio.md deleted file mode 100644 index 541d52deede39197db2a714161188e9b99852896..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/infer-seqio.md +++ /dev/null @@ -1,241 +0,0 @@ -# Running inference on a Model - - -## Introduction - -This page outlines the steps to run inference a model with T5X on Tasks/Mixtures -defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md). - -## Overview - -Running inference on a model with T5X using SeqIO Task/Mixtures consists of the -following steps: - -1. Choose the model to run inference on. -1. Choose the SeqIO Task/Mixture to run inference on. -1. Write a Gin file that configures the model, SeqIO Task/Mixture and other - details of your inference run. -1. Launch your experiment locally or on XManager. -1. Monitor your experiment and access predictions. - -These steps are explained in detail in the following sections. An example run -that runs inference on a fine-tuned T5-1.1-Small checkpoint on the -[(Open Domain) (Open Domain) Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions/) -is also showcased. - -## Step 1: Choose a model - -To run inference on a model, you need a Gin config file that defines the model -params, and the model checkpoint to load from. For this example, a T5-1.1-Small -model fine-tuned on the -[`natural_questions_open_test`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=141&rcl=370261021) -SeqIO Task will be used: - -+ Model checkpoint - - [`cbqa/small_ssm_nq/model.ckpt-1110000`](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/cbqa/small_ssm_nq/) -+ Model Gin file - - [`models/t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin). - -If you would like to fine-tune your model before inference, please follow the -[fine-tuning](finetune.md) tutorial, and continue to Step 2. - -## Step 2: Choose a SeqIO Task/Mixture - -A SeqIO Task encapsulates the data source, the preprocessing logic to be -performed on the data before querying the model, the postprocessing logic to be -performed on model outputs, and the metrics to be computed given the -postprocessed outputs and targets (for inference, post-processing and metrics -are irrelevant). A SeqIO Mixture denotes a collection of Tasks and enables -fine-tuning a model on multiple Tasks. - -Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/), -[SuperGLUE](https://super.gluebenchmark.com/), -[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate), -[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/), -[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been -implemented as SeqIO Tasks/Mixtures and can be used directly. These -Tasks/Mixtures are defined in -[`third_party/py/t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py) -and -[`third_party/py/t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py). - -For the example run, you will run inference on the (Open Domain) Natural -Questions benchmark, which has been implemented as the `natural_questions_open` -Task in -[`/third_party/google_research/google_research/t5_closed_book_qa/t5_cbqa/tasks.py`](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa/t5_cbqa/tasks.py?l=98&rcl=370261021). -Here's an example of a single row of preprocessed data from this Task: - -```json -{ - 'inputs_pretokenized': 'nq question: what was the main motive of salt march', - 'inputs': [3, 29, 1824, 822, 10, 125, 47, 8, 711, 10280, 13, 3136, 10556, 1] - 'targets_pretokenized': 'challenge to British authority', - 'targets': [1921, 12, 2390, 5015, 1], - 'answers': ['challenge to British authority'] -} -``` - -## Step 3: Write a Gin Config - -After choosing the model and SeqIO Task/Mixture for your run, the next step is -to configure your run using Gin. If you're not familiar with Gin, reading the -[T5X Gin Primer](gin.md) is recommended. T5X provides a Gin file that configures -the T5X inference job (located at -[`runs/infer.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin)) to -run inference on SeqIO Task/Mixtures, and expects a few params from you. These -params can be specified in a separate Gin file, or via commandline flags. -Following are the required params: - -+ `CHECKPOINT_PATH`: This is the path to the model checkpoint (from Step 1). - For the example run, set this to - `'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'`. -+ `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run - inference on (from Step 2). For the example run, set this to - `'natural_questions_open'`. -+ `MIXTURE_OR_TASK_MODULE`: This is the Python module that contains the SeqIO - Task or Mixture. For the example run, set this to - `'google_research.t5_closed_book_qa.t5_cbqa.tasks'`. - Note that this module must be included as a dependency in the T5X inference - [binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). Most - common Task modules, including `t5_closed_book_qa`, are already included. If - your module is not included, see the - [Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial - for instructions to add it. -+ `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum length - for that feature. After preprocessing, features are truncated to the - provided value. For the example run, set this to `{'inputs': 38, 'targets': - 18}`, which is the maximum token length for the test set. -+ `INFER_OUTPUT_DIR`: A path to write inference outputs to. When launching - using XManager, this path is automatically set and can be accessed from the - XManager Artifacts page. When running locally using Blaze, you can - explicitly pass a directory using a flag. Launch commands are provided in - the next step. - -In addition to the above params, you will need to import -[`infer.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin) and the -Gin file for the model, which for the example run is -[`t5_1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin). - -```gin -include 'runs/infer.gin' -include 'models/t5_small.gin' -``` - -Note that the `include` statements use relative paths in this example. You will -pass an appropriate `gin_search_paths` flag to locate these files when launching -your run. Absolute paths to Gin files can also be used, e.g. - -```gin -include 't5x/configs/runs/infer.gin' -include 't5x/google/examples/flaxformer_t5/configs/models/t5_1_1_small.gin' -``` - -Finally, your Gin file should look like this: - -```gin -include 'runs/infer.gin' -include 'models/t5_1_1_small.gin' - -CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000' -MIXTURE_OR_TASK_NAME = 'closed_book_qa' -MIXTURE_OR_TASK_MODULE = 'google_research.t5_closed_book_qa.t5_cbqa.tasks' -TASK_FEATURE_LENGTHS = {'inputs': 38, 'targets': 18} -``` - -See -[`t5_1_1_small_cbqa_natural_questions.gin`](https://github.com/google-research/t5x/blob/main/t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions.gin) -for this example. Make sure that your Gin file is linked as a data dependency to -the T5X inference -[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). If your -Gin file is not included, see the -[Advanced Topics section](#custom-t5x-binaries) at the end of this tutorial for -instructions to add it, or skip writing a Gin file and pass the above params as -flags when launching the inference job (see instructions in Step 4). - -## Step 4: Launch your experiment - -To launch your experiment locally (for debugging only; larger checkpoints may -cause issues), run the following on commandline: - -```sh -INFER_OUTPUT_DIR="/tmp/model-infer/" -python -m t5x.infer_unfragmented \ - --gin_file=t5x/google/examples/flaxformer_t5/configs/examples/inference/t5_1_1_small_cbqa_natural_questions.gin \ - --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \ - --alsologtostderr -``` - -Note that multiple comma-separated paths can be passed to the `gin_search_paths` -flag, and these paths should contain all Gin files used or included in your -experiment. - - -## Step 5: Monitor your experiment and parse results - - -After inference has completed, you can view predictions in the `jsonl` files in -the output dir. JSON data is written in chunks and combined at the end of the -inference run. Refer to [Sharding](#sharding) and -[Checkpointing](#checkpointing) sections for more details. - -## Next Steps - -Now that you have successfully run inference on a model, here are some topics -you might want to explore next: - -+ [Fine-tuning a model.](finetune) -+ [Evaluating a model.](eval) -+ [Training a model from scratch.](pretrain) - -We also touch upon a few advanced topics related to inference below that might -be useful, especially when customizing your inference job. - -## Advanced Topics - -### Dataset Sharding {#sharding .no-toc} - -You can run inference in parallel across multiple TPU slices by setting the -`num_shards` flag when running using XManager. When `num_shards > 1`, the -dataset is interleaved among the shards and the predictions are combined in the -end; hence the order of examples in the data source and the predictions in the -output json files will not match (order is guaranteed to match for `num_shards = -1` or the number of input file shards). - -### Dataset Checkpointing {#checkpointing .no-toc} - -You can control dataset checkpointing frequency by overriding the -`infer.checkpoint_period` in -[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin), -which is set to `100` by default. This means that the dataset is checkpointed -after running inferences on `checkpoint_period` batches (batches, not examples; -you can control batch size by overriding `utils.DatasetConfig.batch_size` in -[runs/infer.gin](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/infer.gin), it -is set to `32` by default). - -### Changing Length and Decoding Strategy {#decoding-strategies .no-toc} - -By default, T5X does inference using an arg-max decoding strategy, always -picking the most likely next token. To use random sampling instead, you may -change any of the following parameters in your gin config: - -```gin -decoding.temperature_sample: - temperature = 1.0 - topk = 1 - topp = 0.0 -``` - -You can also control the number of tokens which get generated by specifying: - -```gin -decoding.temperature_sample: - max_decode_steps = 50 -``` - -More detailed documentation on defining a decoding stategy can be found -[here](https://github.com/google-research/t5x/blob/main/docs/usage.md/decoding). - - -### Defining a custom SeqIO Task/Mixture to run inference on {.no-toc} - -Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md). diff --git a/t5x-main/docs/usage/infer.md b/t5x-main/docs/usage/infer.md deleted file mode 100644 index a2f8c6ea45b95ea65a046796aa693f8bdeceabbc..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/infer.md +++ /dev/null @@ -1,16 +0,0 @@ -# Running Inference on a Model - - -This page outlines the steps to run inference a model with T5X. - -Refer to this tutorial when you have an existing model that you want to run -inference on. If you would like to fine-tune your model before inference, please -refer to the [fine-tuning](finetune.md) tutorial. If you'd like to compute -evaluation metrics for your model, please refer to the [evaluation](eval.md) -tutorial. You can also run evals as part of your fine-tuning run. - -T5X supports a few inference modes. Please refer to the appropriate tutorial -based on your use-case: - -1. Run inference on [SeqIO Tasks/Mixtures](infer-seqio.md) -1. Run inference on [TF Example files](infer-files.md) diff --git a/t5x-main/docs/usage/metrics.md b/t5x-main/docs/usage/metrics.md deleted file mode 100644 index 2360f7162f8a239b0854904183dc368465ab559b..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/metrics.md +++ /dev/null @@ -1,266 +0,0 @@ -# Metrics Overview - - -## Introduction - -T5X provides a flexible and customizable library for managing metrics. Metrics -in T5X rely on [CLU](https://github.com/google/CommonLoopUtils/blob/main/README.md), which broadly provides utilities for -writing training loops but specifically provides metric libraries that are -extended by T5X. - - -NOTE: This document currently only applies to train and 'train_eval' metrics, -not to 'infer_eval' metrics, which are implemented using SeqIO. We plan to unify -these three in the future. - -## Metrics and Writers - -CLU provides `Metric` and `MetricWriter` classes. Full details are provided in -[go/clu-metrics](https://github.com/google/CommonLoopUtils/blob/main/README.md-metrics), but a simplified summary will suffice -for our purposes. - -[`clu.metrics.Metric`](https://github.com/google/CommonLoopUtils/tree/main/clu/metrics.py?q=symbol:%5CbMetric%5Cb) -provides an abstract interface for metrics. The interface can be simply -represented by the following: - -```py -class Metric: - - @classmethod - def from_model_output(cls, *args, **kwargs) -> Metric: - # creates a Metric from model output (i.e. loss arrays). - pass - - def merge(self, other) -> Metric: - # combines a Metric from the current step with that of a previous step. - pass - - def compute(self) -> Union[float, np.ndarray]: - # computes the writable value of a metric (as a float, array, etc.) - pass - - def compute_value(self) -> clu.values.Value: - # computes metric as a writable type (Scalar, Image, Histogram, etc.) - # defaults to Scalar - return clu.values.Scalar(self.compute()) -``` - -`Metric`s can then be extended into concrete representations, such as Sum: - -```py -@flax.struct.dataclass -class Sum(Metric): - - total: float - - @classmethod - def from_model_output(cls, values: np.ndarray) -> Metric: - return cls(total=np.sum(values)) - - def merge(self, other) -> Metric: - return type(self)(total = self.total + other.total) - - def compute(self) -> Union[float, array]: - return self.total - - # If Metric is non-Scalar, return a different Value type as needed. - def compute_value() -> clu.values.Value: - return clu.values.Scalar(self.compute()) -``` - -We will elaborate in more detail [below](#a-metric-example) on how Metrics are -practically used in T5X. - -In addition to CLU provided metrics like Average and Accuracy, T5X provides a -few specialized metrics, like TimeRate and AveragePerStep. A full list of CLU -metrics is provided at -[clu/metrics.py](https://github.com/google/CommonLoopUtils/tree/main/clu/metrics.py) while T5X metrics -are listed in [t5x/metrics.py](https://github.com/google-research/t5x/blob/main/t5x/metrics.py). We -will elaborate on specialized metrics like TimeRate and AveragePerStep -[below](#special-t5x-metrics). - -Given a constructed `Metric` object, we can use a `MetricWriter` to write it in -a readable form to some destination. - -`MetricWriter` again has a simple interface, represented by the following - -```py -class MetricWriter: - - def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): - pass - - def write_images(self, step: int, images: Mapping[str, Array]): - pass - - ... -``` - -A `MetricWriter` implements a specific write method for each type, including -scalars, images, audios, texts, histograms, and hyperparameters. - -CLU provides a convenience method for easily writing metric values of diverse -types, -[`clu.metric_writers.write_values`](https://github.com/google/CommonLoopUtils/tree/main/clu/metric_writers/utils.py?q=symbol:%5Cbwrite_values%5Cb). - -``` -def write_values(writer: MetricWriter, step: int, - metrics: Mapping[str, Union[values.Value, values.ArrayType, - values.ScalarType]]): -``` - -Given a mapping of string to -[`clu.values.Value`](https://github.com/google/CommonLoopUtils/tree/main/clu/values.py?q=symbol:%5CbValue%5Cb), -the method automatically calls the writer's appropriate write method. Such a -mapping can be easily obtained by calling `metric.compute_value()`. - -`MetricWriter` is subclassed by several specific writers, which enable writing -to the console, TF summary files, XManager, and others. See -[source](https://github.com/google/CommonLoopUtils/tree/main/clu/metric_writers/) for full details. By -default, the T5X -[`MetricsManager`](https://github.com/google-research/t5x/blob/main/t5x/trainer.py?q=symbol:%5CbMetricsManager%5Cb) -logs metrics to -[TensorBoard](https://github.com/google/CommonLoopUtils/tree/main/clu/metric_writers/summary_writer.py), -[XManager](https://github.com/google/CommonLoopUtils/tree/main/clu/metric_writers/google/xm_measurement_writer.py), -and -[INFO logs](https://github.com/google/CommonLoopUtils/tree/main/clu/metric_writers/logging_writer.py). -In the future, the set of writers used will be made more easily customizable. - -## Usage in T5X - -In a T5X Model, we have a `loss_fn` that returns a dictionary of metrics, -mapping string name to `Metric` objects. In the simplest case, this may involve -creating a dictionary such as the following: - -```py -metrics = { - 'nonpadding_tokens_fraction': Average(mask.sum(), count=mask.size()), - 'accuracy': Accuracy.from_model_output( - logits=logits, labels=targets.astype(jnp.int32), mask=mask) -} -``` - -`Metric` objects can either be intialized directly or by using -`from_model_output`. - -The metrics created on one particular training step (one call of the loss -function) are accumulated over subsequent steps (using the `merge` method). - -NOTE: Unlike in previous versions of T5X, "initial metrics" should not be -defined, since the first metrics returned from `loss_fn` are treated as the -initial metrics for later accumulation. - -Finally, in order to summarize the metrics into writable forms, we can simply -use the following: - -```py -summary = {k: m.compute() for k, m in metrics.items()} -``` - -Typically, the above call will not be necessary, since the T5X `BaseModel` -already includes it automatically. - -### A Metric Example - -Let's imagine that we want to create a metric that tracks the loss per number of -tokens. One (bad) way of doing this would be the following: - -```py {.bad} -# create metrics and return from loss_fn -metrics = { - 'loss': Sum(total=jnp.sum(loss)) - 'num_tokens': Sum(total=num_tokens) -} - -# run for many steps, metrics get merged and accumulated - -# summarize metrics -summary = { - 'loss_per_all_target_tokens': - metrics['loss'].compute() / metrics['num_tokens'].compute() -} -``` - -If this looks familiar, then you may be used to the old way of handling metrics -in T5X. This is obviously less than ideal, since we track two "metrics" that -we're not interested in, which we use to compute the actual one metric we want. - -A better way of implementing this could be more like this: - -```py -# create metrics and return from loss_fn -metrics = { - 'loss_per_all_target_tokens': Average(total=jnp.sum(loss), count=num_tokens) -} - -# run for many steps, metrics get merged and accumulated - -# summarize metrics -summary = {k: m.compute() for k, m in metrics.items()} -``` - -There are a few advantages of this change. First, we don't need to implement any -new logic in the summarization step - we can simply reuse the generic logic. -Second, our metric, `loss_per_all_target_tokens`, is explicitly created as an -`Average`, and is tracked throughout training, with no extraneous intermediate -metrics. - -NOTE: If you see live code like the former example, it is part of our ongoing -migration towards new-style metrics in T5X. Please help us clean it up! - -### Special T5X Metrics - -A few metrics are somewhat more complicated to use, largely due to limitations -of the T5X training library. Metrics can be found at -[`t5x/metrics.py`](https://github.com/google-research/t5x/blob/main/t5x/metrics.py). - -#### `AveragePerStep` - -When dealing with per-step metrics, use `AveragePerStep`. This could correspond -to metrics such as loss per step. It cannot be implemented simply using a -standard `Average` metric because the loss function, where the metric is -initially computed, may be run multiple times if we have multiple microbatches. -If we have two microbatches, this results in the metric being initialized twice -per step. Thus, we defer setting number of steps at creation time and set it -before the metrics are summarized. - -For example, we need to initialize `z_loss` and `steps_per_second` as follows: - -```py -'z_loss': AveragePerStep.from_model_output(z_loss) -``` - -Then, before summarization -[`set_step_metrics_num_steps(metrics, num_steps)`](https://github.com/google-research/t5x/blob/main/t5x/metrics.py;l=222) -is called automatically to set the number of steps for relevant metrics. - -#### `TimeRate` - -Another special metric is `TimeRate`, which is used to measure metrics over a -period of time. Our complication here is that the start time of the metric -cannot be set when the metric is created, since creation happens inside a JAX -compiled function. Instead, we must set the duration on the host. - -For example, we can initialize a `seqs_per_second` metric as follows: - -```py -'timing/seqs_per_second': TimeRate(numerator=num_examples) -``` - -Before summarization, -[`set_time_rate_metrics_duration(metrics, duration)`](https://github.com/google-research/t5x/blob/main/t5x/metrics.py;l=209) -is called automatically called to set the duration of time-related metrics. - -#### `StepsPerTime` - -This metric represents the sythesis of the above two, which can represent a -metric such as `steps_per_second`. - -```py -'timing/steps_per_second': StepsPerTime() -``` - -NOTE: Unless you are also overriding -[Trainer](https://github.com/google-research/t5x/blob/main/t5x/trainer.py;l=314), you likely only -need to worry about initializing metrics correctly, and not about making later -adjustments for duration and number of microbatches. diff --git a/t5x-main/docs/usage/partitioning.md b/t5x-main/docs/usage/partitioning.md deleted file mode 100644 index 2a62c76d11ee70a5a3042144c45c0f83c6b8ba4e..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/partitioning.md +++ /dev/null @@ -1,429 +0,0 @@ -# Data, Model, and Activation Partitioning - - -TL;DR: The recommended way of specifying partitions in T5X. - -**Partitioning** is the process dividing and replicating machine learning *model -parameters*, *activations*, and *data* across the accelerator devices (TPU/GPU) -in order to: - -* Train and infer from models too large to fit in the memory of a single - device -* Use extremely large batch sizes -* Train faster - -## How to Partition - -Partitioning in T5X is configured in two steps: - -1. Specify logical axes names for parameter and activation array dimensions -2. Map the logical names to the physical axes of the accelerator mesh - -Let's take a closer look at each of these steps. - -**Note:** In T5X, partitioning is primarily provided through the -[jax.pjit][pjit] backend via `PjitPartitioner` using the Gin configuration -framework. - -### Specify logical axis names - -**Logical axis names** are a user-configured shorthand for grouping *axes* (aka -*dimensions*) of either parameter or activation arrays in a model -implementation. - -For example, you could refer to the axes of the inputs to a model as `('batch', -'length', 'vocab')`. If the parameters of the embedding matrix are labelled -`('vocab', 'embed')` then the activations following embedding should be named -`('batch', 'length', 'embed')`. - -**Description** | **Logical Axis Names** --------------------- | ------------------------------ -Inputs to model | `('batch', 'length', 'vocab')` -Embedding parameters | `('vocab', 'embed')` -Activations | `('batch', 'length', 'embed')` - -**How to configure logical axis names** - -Logical axis annotations can be provided through the utilities in -[`flax.linen.partitioning`][lan]. - -Instead of calling `self.param` to create parameters within your model -implementation, use the `flax.linen.partitioning.param_with_axes` API to -communicate axis names for each parameter. - -```py -from flax.linen import partitioning - -scale = partitioning.param_with_axes( - 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) -``` - -For an example in context, see [`layers.py`][param_with_axes]. - -Tip: We recommend you use the *canonical* logical axis names listed -[below](#canonical-logical-axis-names). - -To specify the logical axes for *activation partitioning*, provide the logical -axes names to `flax.linen.partitioning.with_sharding_constraint` (instead of -using `jax.pjit.with_sharding_constraint` or -`t5x.partitioning.with_sharding_constraint`). - -```py -from flax.linen import partitioning - -... -output = jnp.dot(x, embedding) -output = with_sharding_constraint(output, ('batch', 'length', 'embed')) -return output -``` - -### Map logical names to device - -For `jax.pjit` to know how to partition these arrays across the hardware, the -logical axis names must be mapped to the physical axes of the accelerator mesh. - -**Note:** A *mesh* is an n-dimensional array of TPU (or GPU) processors, -connected by a network. The TPUv3 processor is limited to 2D meshes. The TPUv4 -processor can handle 3D meshes. - -In T5X, the two primary *hardware* axes are named `'data'` and `'model'`, -referring to the default mappings for data- and model-parallelism. - -> **Note:** You are actually free to map model parameters or activations across -> the `'data'` axis. In fact, this is what is done in 2D parameter/activation -> sharding. To see how this works in practice, see: -> -> * [The example mappings](#example-configurations) below -> * [`t5x.partitioning.standard_logical_axis_rules`][standard-rules] -> implementation - - -#### Configuring `PjitPartitioner` - -`PjitPartitioner` has three primary constructor arguments: - -* `model_parallel_submesh` -* `num_partitions` -* `logical_axis_rules` - -The `model_parallel_submesh` and `num_partitions` arguments provide two -mutually-exclusive methods of specifying the submesh of devices to use for model -partitioning. As a rule of thumb: - -* Use `model_parallel_submesh` when you want to specify how the logical names - are mapped to the device -* Use`num_partitions` for an automatic mapping - -**Using `model_parallel_submesh`** - -The `PjitPartitioner` constructor argument that provides the most control is: - -``` -model_parallel_submesh(Tuple[int, int, int, int]) -``` - -It is a 4-tuple that specifies the `(x, y, z, c)` *model-parallel* submesh–an -axis of accelerator parallelism orthogonal to data parallelism. Axes in a -model's parameter or activation arrays can be sharded over this submesh using -axis rules that map them to `'model'`. - -**Note:** The effective number of model subpartitions is equal to -`np.prod(model_parallel_submesh)` and must evenly divide the total number of -devices. Specifically: \ -`jax.device_count() % np.prod(model_parallel_submesh) == 0`. - -The rest of the TPU mesh is the *data parallel* submesh, providing -`jax.device_count() // np.prod(model_parallel_submesh)` partitions. It is used -for data (aka *batch*) parallelism and to shard other array axes that are mapped -to `'data'`. - -**Using `num_partitions`** - -Alternatively, - -``` -num_partitions(int) -``` - -accepts an integer that specifies the size of the model parallel submesh to be -*automatically* selected for the current topology. - -**Using `logical_axis_rules`** - -The third key argument is - -``` -logical_axis_rules(Sequence[Tuple[str, Optional[str]]]) -``` - -This argument accepts a priority-ordered sequence of key-value (KV) tuples. -These tuples map the logical axis names to hardware resources, using `'model'` -and `'data'` as the two primary hardware axes. Specifically, each logical axis -can be mapped to one of: - -* `None` to disable sharding, and thus be fully-replicated -* `'model'` to shard across the model-parallel submesh -* `'data'` to shard across the data-parallel submesh - -The same key can be mapped to multiple values. For each array, mappings are -applied in priority order. If a hardware resource has already been assigned in -to a different axis and multiple keys exist, a latter mapping may be used. - -For example, consider the following set of logical axis rules: - -```py -[ - ('head', 'model'), - ('embed', 'model'), - ('embed', 'data'), - ('vocab', 'model'), -] -``` - -For an array with logical axes `('embed', 'head')`, `'head'` will first be -mapped to `'model'`, since it comes first in the priority list. Next, `'embed'` -will be mapped to `'data'`, since `'model'` has already been used. However, an -array with logical axes `('vocab', 'embed')` will receive the mapping `(None, -'model')` since `'embed'` has a higher priority than `'vocab'`. - -T5X provides the `t5x.partitioning.standard_logical_axis_rules()` function to -generate canonical logical axis rule sets depending on how many mesh dimensions -you wish to shard. This assumes that you are using -[canonical logical axis names](#canonical-logical-axis-names). - -For details, see -[`t5x.partitioning.standard_logical_axis_rules()`][standard-rules]. - -## Other Stuff - -### Overriding axis names from an external codebase - -You may wish to incorporate Flax modules from an external codebase into your -model implementation that uses `self.param` instead of -`flax.linen.partitioning.param_with_axes`, or that may use axis names that are -incompatible with your codebase. - -To deal with this situation, we provide the `utils.override_params_axes_names` -helper function. This helper can be called at the end of -`Model.get_initial_variables` to apply a priority-ordered mapping from regex -patterns (fully matching parameter names) to tuples containing string logical -axis names to replace model-derived names. - -For example, the following configuration provides logical axis names for an -external module called 'external_mlp' used in every layer of the model's -encoder, without modifying any other modules: - -```py -class MyCustomEncoderDecoderModel(models.EncoderDecoderModel): - - def get_initial_variables( - self, - rng: jnp.ndarray, - input_shapes: Mapping[str, Array], - input_types: Optional[Mapping[str, jnp.dtype]] = None - ) -> flax_scope.FrozenVariableDict: - initial_variables = super().get_initial_variables( - rng=rng, input_shapes=input_shapes, input_types=input_types) - return utils.override_params_axes_names( - initial_variables, - params_axes_names_override=[ - ('encoder/layer_\\d/external_mlp/kernel':, ('embed', 'mlp')), - ('encoder/layer_\\d/external_mlp/bias':, ('mlp',)), - ]) -``` - -**Note:** It is not possible to add or modify activation partitioning in an -external module. - -### Canonical logical axis names - -Use the following logical axis names to be compatible with -[`t5x.partitioning.standard_logical_axis_rules`][standard-rules]: - -| Logical Axis Name | Description | -| -------------------- | ---------------------------------------------------- | -| `"embed"` | The common "activation_dim" in the network, first | -: : emitted by the embedding layer. : -| `"heads"` | Number of heads for attention/relative position | -: : biases. : -| `"kv"` | For query/key/value hidden dimensions of each head. | -| `"joined_kv"` | For "heads * kv" fused dimension of attention | -: : matrices, when the kernel is reshaped such that : -: : "heads" and "kv" are packed in the same dimension. : -| `"mlp"` | Intermediate dimension of the feed-forward layer. | -| `"vocab"` | For embeddings, the input/output vocabulary size. | -| `"mlp_activations"` | For fused MLP matrices that have a dimension for the | -: : activation function index. : -| `"stack"` | For KV and QKV fused attention implementations, the | -: : manual parameter-fusion stacked dimension. : -| `"abspos_buckets"` / | The dimension for positional bias buckets. | -: `"relpos_buckets"` : : - -If you wish to use a non-canonical axis name, you will need to pass a custom set -of axis rules to the `PjitPartitioner`. - --------------------------------------------------------------------------------- - -## Example configurations - -### Automatic - Full 2D partitioning - -You can override the default 1D sharding configuration by modifying the -arguments to [`t5x.partitioning.standard_logical_axis_rules`][standard-rules]. -For example, for full parameter and activation 2D partitioning you can set: - -```py -from t5x import partitioning - -train_script.train: - partitioner = @partitioning.PjitPartitioner() - -partitioning.PjitPartitioner: - num_partitions = 1 - logical_axis_rules= @partitioning.standard_logical_axis_rules() - -partitioning.standard_logical_axis_rules: - activation_partitioning_dims = 2 - parameter_partitioning_dims = 2 -``` - -### Manual configurations - -Alternatively, you can manually set the rules, experimenting with some of the -following options: - -* [Data-only parallelism](#data-only-parallelism) -* [Data parallel with parameter gather](#data-parallel-with-parameter-gather) -* [Data and model parallel with replicated activations](#data-and-model-parallel-with-replicated-activations) -* [Data and model parallel with sharded activations](#data-and-model-parallel-with-sharded-activations) -* [Full 2D sharding](#full-2d-sharding) - -#### Data-only parallelism - -```py -partitioning.PjitPartitioner.logical_axis_rules = [ - ('batch', 'data'), - ('vocab', None), - ('embed', None), - ('mlp', None), - ('heads', None), - ('kv', None), - ('joined_kv', None), - ('relpos_buckets', None), - ('abspos_buckets', None), - ('length', None), - ('layers', None), - ('stack', None), - ('mlp_activations', None), -] -``` - -#### Data parallel with parameter gather - -An example of 2D parameter partitioning with trival MP submesh, such as -[ZeRO-3][ZeRO-3]. - -```py -partitioning.PjitPartitioner.logical_axis_rules = [ - ('batch', 'data'), - # all weight matrices have this axis; activations already shard it along 'data' - ('embed', 'data'), - ('vocab', None), - ('mlp', None), - ('heads', None), - ('kv', None), - ('joined_kv', None), - ('relpos_buckets', None), - ('abspos_buckets', None), - ('length', None), - ('layers', None), - ('stack', None), - ('mlp_activations', None), -] -``` - -#### Data and model parallel with replicated activations - -An example of 1D parameter partitioning, such as [Megatron][megatron]. - -```py -partitioning.PjitPartitioner.logical_axis_rules = [ - ('batch', 'data'), - ('mlp', 'model'), - ('heads', 'model'), - ('vocab', 'model'), - ('embed', None), - ('kv', None), - ('joined_kv', None), - ('relpos_buckets', None), - ('abspos_buckets', None), - ('length', None), - ('layers', None), - ('stack', None), - ('mlp_activations', None), -] -``` - -#### Data and model parallel with sharded activations - -An example of 1D parameter partitioning with 2D activation partitioning, such as -[Optimus][optimus]. - -```py -partitioning.PjitPartitioner.logical_axis_rules = [ - ('batch', 'data'), - ('mlp', 'model'), - ('heads', 'model'), - ('vocab', 'model'), - # shard remaining activations; weight matrices already have axes mapped to 'model' - ('embed', 'model'), - ('kv', None), - ('joined_kv', None), - ('relpos_buckets', None), - ('abspos_buckets', None), - ('length', None), - ('layers', None), - ('stack', None), - ('mlp_activations', None), -] -``` - -#### Full 2D sharding - -An example of 2D parameter and activation partitioning, such as -[GShard][gshard]. - -```py -partitioning.PjitPartitioner.logical_axis_rules = [ - ('batch', 'data'), - ('mlp', 'model'), - ('heads', 'model'), - ('vocab', 'model'), - # shard both activations and weight matrices on the remaining available axis - ('embed', 'model'), - ('embed', 'data'), - ('kv', None), - ('joined_kv', None), - ('relpos_buckets', None), - ('abspos_buckets', None), - ('length', None), - ('layers', None), - ('stack', None), - ('mlp_activations', None), -] -``` - - - - - -[ZeRO-3]: https://arxiv.org/abs/1910.02054 -[gshard]: https://arxiv.org/abs/2105.04663 -[flaxformer]: https://github.com/google/flaxformer/tree/main/flaxformer/architectures/t5/ -[lan]: https://github.com/google/flax/tree/main/flax/linen/partitioning.py -[megatron]: https://arxiv.org/abs/1909.08053 -[minimal]: https://github.com/google-research/t5x/blob/main/t5x/examples/ -[optimus]: https://arxiv.org/abs/2104.05343 -[param_with_axes]: https://github.com/google-research/t5x/blob/main/t5x/examples/t5/layers.py;rcl=427300354;l=462 -[pjit]: https://github.com/google/jax/tree/main/jax/experimental/pjit.py -[standard-rules]: https://github.com/google-research/t5x/blob/main/t5x/partitioning.py?l=438&rcl=421294093 diff --git a/t5x-main/docs/usage/pretrain.md b/t5x-main/docs/usage/pretrain.md deleted file mode 100644 index 7cc37f07fbce3fcb3b071018e2b48f1c3114ab9c..0000000000000000000000000000000000000000 --- a/t5x-main/docs/usage/pretrain.md +++ /dev/null @@ -1,213 +0,0 @@ -# Pretraining a model - - -## Introduction - -This page outlines the steps to pretrain a model with T5X on common tasks -defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md). - -## Overview - -Pretraining a model with T5X consists of the following steps: - -1. Choose the model architecture. -2. Choose the SeqIO Task/Mixture to for training. -3. Write a Gin file that configures the model, SeqIO Task/Mixture and other - details of your pretraining run. -4. Launch your experiment locally or on XManager. -5. Monitor your experiment. - -These steps are explained in detail in the following sections. An example run -that trains a T5 1.1 Small checkpoint from scratch on the C4 dataset using the -span corruption pretraining objective is also showcased. - -## Step 1: Choose a model architecture - -To train a model, you need a Gin config file that defines the model params. For -your convenience, Gin configs for common models have been made available for use -in T5X. A list of all the available pre-trained models (with model checkpoints -and Gin config files) are available in the [Models](https://github.com/google-research/t5x/blob/main/docs/models.md) -documentation. - -For the example run, you will use the T5 1.1 Small model. The Gin file for this -model is located at -[`/t5x/examples/t5/t5_1_1/1_1_small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin). - -## Step 2: Choose a SeqIO Task/Mixture - -A SeqIO Task encapsulates the data source, the preprocessing logic to be -performed on the data before querying the model, the postprocessing logic to be -performed on model outputs, and the metrics to be computed given the -postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks -and enables pretraining a model on multiple Tasks simultaneously. - -Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/), -[SuperGLUE](https://super.gluebenchmark.com/), -[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate), -[SQUAD](https://rajpurkar.github.io/SQuAD-explorer/), -[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been -implemented as SeqIO Tasks/Mixtures and can be used directly. These -Tasks/Mixtures are defined in -[`third_party/py/t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py) -and -[`third_party/py/t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py). - -For the example run, you will train the model on -[`c4_v220_span_corruption`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py?l=42&rcl=370153959) -Task that implements the span corruption pretraining objective using the C4 -dataset. This is the final pretraining Task used in the -[T5 paper](https://arxiv.org/pdf/1910.10683.pdf%C3%82%C2%A0). - -TIP: Want to use a custom Task or Mixture? See section below called "Adding -SeqIO Task/Mixture modules and Gin files" - -## Step 3: Write a Gin Config - -After choosing the model architecture and SeqIO Task/Mixture for your run, the -next step is to configure your run using Gin. If you're not familiar with Gin, -reading the [T5X Gin Primer](gin.md) is recommended. - -T5X provides a Gin file that configures the T5X trainer for pretraining (located -at -[`runs/pretrain.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/pretrain.gin)), -and expects a few params from you. These params can be specified in a separate -Gin file, or via commandline flags. Following are the required params: - -+ `TRAIN_STEPS`: Number of training steps. For the example run, set this to - `100_000`. -+ `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run (from - Step 2). For the example run, set this to `'c4_v220_span_corruption'`. -+ `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int - length for that feature. After preprocessing, features are truncated to the - provided value. For the example run, set this to `{"inputs": 512, "targets": - 114}`, following the original T5 pretraining setup. -+ `MODEL_DIR`: A path to write pretrained checkpoints to. When launching using - XManager, this path is automatically set and can be accessed from the - XManager Artifacts page. When running locally using Blaze, you can - explicitly pass a directory using a flag. Launch commands are provided in - the next step. - -In addition to the above params, you will need to import -[`pretrain.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/pretrain.gin) -and the Gin file for the pretrained model, which for the example run is -[`t5_1_1/small.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/small.gin). - -```gin -include 't5x/configs/runs/pretrain.gin' -include 't5x/examples/t5/t5_1_1/small.gin' -``` - -Note that the `include` statements can use relative paths in this example for -which You will pass an appropriate `gin_search_paths` flag to locate these files -when launching your run. However, we recommend that you use absolute paths -because it can be more difficult to locate the gin files speicified via relative -paths without inspecting the launch command. - -You will also need to import the Python module(s) that register SeqIO Tasks and -Mixtures used in your run. For the example run, we add `import t5.data.mixtures` -since it is where 'glue_v002_proportional' is registered. Note that this module -must also be included as a dependency in the T5X trainer -[binary](https://github.com/google-research/t5x/blob/main/t5x/BUILD;l=74;rcl=398627055). Most -common Task/Mixture modules, such as this one, are already included. If your -module is not included, see the [Advanced Topics section](#custom-t5x-binaries) -at the end of this tutorial for instructions to add it. - -Finally, your Gin file should look like this: - -```gin -include 't5x/examples/t5/t5_1_1/small.gin' -include 't5x/configs/runs/pretrain.gin' - -# Register necessary SeqIO Tasks/Mixtures. -import t5.data.mixtures - -MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} -TRAIN_STEPS = 10000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 256 -``` - -See -[`t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin`](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin) -for this example. - - -## Step 4: Launch your experiment - -To launch your experiment locally (for debugging only; larger checkpoints may -cause issues), run the following on commandline: - -```sh -MODEL_DIR="/tmp/pretrain-model/" -python -m t5x.train_unfragmented \ - --gin_file=t5x/examples/t5/t5_1_1/c4_pretrain_small.gin \ - --gin.MODEL_DIR=\"${MODEL_DIR}\" \ - --alsologtostderr -``` - -Note that multiple comma-separated paths can be passed to the `gin_search_paths` -flag, and these paths should contain all Gin files used or included in your -experiment. - - -## Next Steps - -Now that you have successfully pretrained a model, here are some topics you -might want to explore next: - -+ [Fine-tuning a model.](finetune.md) -+ [Evaluating a fine-tuned model.](eval.md) -+ [Running inference on a fine-tuned model.](infer.md) - -We also touch upon a few advanced topics related to pretraining below that might -be useful, especially when customizing your pretraining job. - -## Advanced Topics - -### `train`, `train_eval` {#train-eval .no-toc} - -A -[`DatasetConfig`](https://github.com/google-research/t5x/blob/main/t5x/utils.py?l=113&rcl=375475889) -object is used to configure loading SeqIO Tasks/Mixtures for training and eval. -If you take a closer look at -[`runs/pretrain.gin`](https://github.com/google-research/t5x/blob/main/t5x/configs/runs/pretrain.gin), -you will see that there are two `DatasetConfig` objects defined and passed to -the train function: `train_dataset_cfg` and `train_eval_dataset_cfg`. Here's a -brief description of these configs: - -+ `train`: This configures the Task/Mixture that the model will be pretrained - on. -+ `train_eval`: This configures the Task/Mixture that is used to compute - training metrics on the eval split, e.g. perplexity. These metrics are - defined in the - [`Model`](https://github.com/google-research/t5x/blob/main/t5x/models.py;l=257-266;rcl=394045248) - class and the eval fn is located - [here](https://github.com/google-research/t5x/blob/main/t5x/trainer.py?l=212&rcl=371778063). - -### Deterministic training {.no-toc} - -A training run may consist of various randomized operations, e.g. dataset -shuffling, dropout, etc. However, it is often useful to have deterministic -training, meaning that the random operations are reproducible and robust to -preemption/restarts. To make your pretraining deterministic, in addition to the -params configured in `pretrain.gin`, you need to add the following configs: - -+ sets the dataset seed to a fixed value: `train/utils.DatasetConfig.seed = - 42`. -+ sets the dropout seed to a fixed value: `train_script.train.random_seed = - 42`. -+ enables dataset checkpointing: `utils.SaveCheckpointConfig.save_dataset = - True`. This means that the dataset iterator is checkpointed periodically - during training, and in case of preemptions, training resumes from the - latest dataset checkpoint to ensure deterministic behavior. The - checkpointing frequency is set using `utils.SaveCheckpointConfig.period` - (`1000` by default), meaning that the dataset is checkpointed after - processing `1000` batches (batches, not examples; batch size can be - overridden using `train/DatasetConfig.batch_size` and is set to `128` by - default). - - -### Defining a custom SeqIO Task/Mixture to pretrain on {.no-toc} - -Refer to [SeqIO documentation](https://github.com/google/seqio/blob/main/README.md). diff --git a/t5x-main/pytest.ini b/t5x-main/pytest.ini deleted file mode 100644 index 5f1cd9f4e5ef281a47dc69180674da15e6b28f56..0000000000000000000000000000000000000000 --- a/t5x-main/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -python_files = *_test.py -log_level = INFO \ No newline at end of file diff --git a/t5x-main/readthedocs.yaml b/t5x-main/readthedocs.yaml deleted file mode 100644 index d13bca4886c2d072183baf7a5479ed67fa5c342a..0000000000000000000000000000000000000000 --- a/t5x-main/readthedocs.yaml +++ /dev/null @@ -1,31 +0,0 @@ -# .readthedocs.yaml -# Read the Docs configuration file -# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details - -# Required -version: 2 - -# Set the version of Python and other tools you might need -build: - os: ubuntu-22.04 - tools: - python: "3.10" - # You can also specify other tool versions: - # nodejs: "16" - # rust: "1.55" - # golang: "1.17" - -# Build documentation in the docs/ directory with Sphinx -sphinx: - configuration: docs/conf.py - -# If using Sphinx, optionally build your docs in additional formats such as PDF -# formats: -# - pdf - -# Optionally declare the Python requirements required to build your docs -python: - install: - - requirements: docs/requirements.txt - - method: pip - path: . \ No newline at end of file diff --git a/t5x-main/setup.py b/t5x-main/setup.py deleted file mode 100644 index bcd14ea99bbd3f4fb3eecdda4cea25d226c9afb0..0000000000000000000000000000000000000000 --- a/t5x-main/setup.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Install T5X.""" - -import os -import sys -import setuptools - -# To enable importing version.py directly, we add its path to sys.path. -version_path = os.path.join(os.path.dirname(__file__), 't5x') -sys.path.append(version_path) -from version import __version__ # pylint: disable=g-import-not-at-top - -# Get the long description from the README file. -with open('README.md') as fp: - _LONG_DESCRIPTION = fp.read() - -_jax_version = '0.4.16' -_jaxlib_version = '0.4.16' - -setuptools.setup( - name='t5x', - version=__version__, - description='T5-eXtended in JAX', - long_description=_LONG_DESCRIPTION, - long_description_content_type='text/markdown', - author='Google Inc.', - author_email='no-reply@google.com', - url='http://github.com/google-research/t5x', - license='Apache 2.0', - packages=setuptools.find_packages(), - package_data={ - '': ['**/*.gin'], # not all subdirectories may have __init__.py. - }, - scripts=[], - install_requires=[ - 'airio @ git+https://github.com/google/airio#egg=airio', - 'absl-py', - 'cached_property', - 'clu @ git+https://github.com/google/CommonLoopUtils#egg=clu', - 'flax @ git+https://github.com/google/flax#egg=flax', - 'fiddle >= 0.2.5', - 'gin-config', - f'jax >= {_jax_version}', - f'jaxlib >= {_jaxlib_version}', - ( - 'jestimator @' - ' git+https://github.com/google-research/jestimator#egg=jestimator' - ), - 'numpy', - 'optax @ git+https://github.com/deepmind/optax#egg=optax', - 'orbax-checkpoint >= 0.5', - 'seqio @ git+https://github.com/google/seqio#egg=seqio', - 'tensorflow-cpu', - 'tensorstore >= 0.1.20', - ], - extras_require={ - 'gcp': [ - 'gevent', - 'google-api-python-client', - 'google-compute-engine', - 'google-cloud-storage', - 'oauth2client', - ], - 'test': ['pytest', 't5'], - # Cloud TPU requirements. - 'tpu': [f'jax[tpu] >= {_jax_version}'], - 'gpu': [ - 'ipdb==0.13.9', - 'fasttext==0.9.2', - 'pysimdjson==5.0.2', - 'pytablewriter==0.64.2', - 'gdown==4.5.3', - 'best-download==0.0.9', - 'lm_dataformat==0.0.20', - 'dllogger@git+https://github.com/NVIDIA/dllogger#egg=dllogger', - 'tfds-nightly', - 't5==0.9.4', - ], - }, - classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - ], - keywords='text nlp machinelearning', -) diff --git a/t5x-main/t5x/__init__.py b/t5x-main/t5x/__init__.py deleted file mode 100644 index 70746335adc83e6812fc23a86be55383948fc785..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Import API modules.""" - -import t5x.adafactor -import t5x.checkpoints -import t5x.decoding -import t5x.gin_utils -import t5x.infer -import t5x.losses -import t5x.models -import t5x.partitioning -import t5x.state_utils -import t5x.train_state -import t5x.trainer -import t5x.utils -# Version number. -from t5x.version import __version__ # pylint: disable=g-importing-member - -# TODO(adarob): Move clients to t5x.checkpointing and rename -# checkpoints.py to checkpointing.py -checkpointing = t5x.checkpoints diff --git a/t5x-main/t5x/adafactor.py b/t5x-main/t5x/adafactor.py deleted file mode 100644 index 1920fd5ceacd5aa1722a03337dad2c1ab33230bd..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/adafactor.py +++ /dev/null @@ -1,683 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Adafactor Optimizer. - -Specialized Adafactor implementation for T5X with: - - custom factorization specification rules. - - support for stacked parameters from scanned layers and parameter fusions. - -Why do we need custom factorization? In the Adafactor paper, scalar, vector and -matrix parameters are considered. This is sufficiently general because higher -dimensional parameters can be reshaped. In practice, there are situations where -higher dimensional parameters are desirable. For example, consider the -multi-headed attention. It has projection kernels. This is naturally -represented as 3-dimensional array [d_model, num_head, head_dim]. Keeping the -3-dimensional structure can be beneficial for performance optimization, e.g., by -giving compilers additional degree of freedom to do layout optimization. - -The default heuristic behavior for the second-moment estimator can lead to an -unexpected result because it assumes that the parameters are matrices (vectors -and scalars are not factored). The dimensions are sorted and the smaller -dimension is assigned to the row dim and the larger dim to the col dim (unless -the two largest dims have an equal size and then the original ordering of the -dimensions is used). Then `v_row` (i.e., the optimizer state for the row) is -obtained by removing the col dim. In other words, `rank(v_row) = rank(v) - 1`. -If the parameter is higher dimensional, v_row and v_col are higher dimensional. -Therefore, the outer product of v_row and v_col do not necessarily corresponds -to the row rank approximation that minimizes the generalized Kullback-Leibler -divergence (the original Adafactor formulation). - -This Adafactor implementation generalized the default behavior such that we -obtain the correct second moment estimator even for higher dimensional -parameters. - -""" - -import enum -import re -import typing -from typing import Any, Mapping, Optional, Sequence, Tuple, Union - -from absl import logging -from flax import struct -from flax.core import freeze -from flax.core import FrozenDict -from flax.core import unfreeze -from flax.serialization import from_state_dict -from flax.serialization import to_state_dict -from flax.traverse_util import flatten_dict -from flax.traverse_util import unflatten_dict -import jax -import jax.numpy as jnp -import numpy as np -from t5x import utils -from t5x.optimizers import OptimizerDef -from t5x.optimizers import OptimizerState - -Dtype = Any - - -class FactorDim(enum.Enum): - # Don't factorize this dimension. - NONE = None - # A batch-like dimension that we should not average over. - BATCH = 1 - ROW = 2 - COLUMN = 3 - - -# Sentinel value signifying the legacy heuristic factorization rule. -class HeuristicRule(enum.Enum): - token = 1 - - -HEURISTIC_RULE = HeuristicRule.token -FactorRule = Union[HeuristicRule, Tuple[FactorDim, ...]] - - -def _restore(target, flat): - state_dict = unflatten_dict({tuple(k.split('/')): v for k, v in flat.items()}) - if isinstance(target, FrozenDict): - return freeze(state_dict) - else: - return state_dict - - -def _insert(tpl, idx, x): - tmp = list(tpl) - tmp.insert(idx, x) - return tuple(tmp) - - -def standard_logical_factor_rules(): - return freeze({ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - 'mlp': FactorDim.COLUMN, - 'heads': FactorDim.COLUMN, - 'kv': FactorDim.COLUMN, - 'joined_kv': FactorDim.COLUMN, - 'relpos_buckets': FactorDim.NONE, - 'layers': FactorDim.BATCH, # used in scanned layers - 'stack': FactorDim.BATCH, # used in stacked params - # 'batch', 'length' should not occur in parameters - 'q_wi_fused': FactorDim.COLUMN, - 'o_wo_fused': FactorDim.COLUMN, - 'multiquery_heads': FactorDim.COLUMN, - 'kv_fused': FactorDim.COLUMN, - 'layer_norm_scale': FactorDim.NONE, - 'mlp_activations': FactorDim.COLUMN, - }) - - -def factor_name_to_factordim(name): - if not isinstance(name, str): - return name - name = name.lower() - return { - 'row': FactorDim.ROW, - 'col': FactorDim.COLUMN, - 'column': FactorDim.COLUMN, - 'batch': FactorDim.BATCH, - 'none': FactorDim.NONE, - 'unfactorized': FactorDim.NONE, - }[name] - - -class HParamMap: - """Maps parameter path names to hparams. - - Names of parameters nested in a PyTree (e.g., an Optimizer) are formed by - joining the names along the path to the parameter leaf with '/'. - """ - - def __init__(self, rules): - self._rules = [(re.compile(r), p) for r, p in rules] - - def __getitem__(self, key: str) -> Any: - for r, p in self._rules: - if r.search(key): - return p - raise KeyError(f'No factor rule found for parameter: {key}') - - def __call__(self, params): - """Returns a copy of the params with mapped hparams in leaves.""" - flat_state_dict = flatten_dict(to_state_dict(params)) - flat_rules_dict = {k: self['/'.join(k)] for k in flat_state_dict.keys()} - return from_state_dict(params, unflatten_dict(flat_rules_dict)) - - -@struct.dataclass -class _AdafactorHyperParams: - """Hparams for Adafactor optimizer.""" - - learning_rate: Optional[float] - factored: bool - multiply_by_parameter_scale: Union[bool, HParamMap] - beta1: Optional[float] - decay_rate: float - step_offset: int - clipping_threshold: Optional[float] - weight_decay_rate: Optional[float] - min_dim_size_to_factor: int - epsilon1: float - epsilon2: float - factor_map: Optional[HParamMap] = None - logical_factor_rules: Any = None - weight_decay_rate_lr_exponent: Optional[float] = None - global_norm_clip_threshold: Optional[float] = None - max_parameter_scale: Optional[float] = None - skip_nan_updates: Optional[bool] = False - - -@struct.dataclass -class _AdafactorParamState: - v_row: np.ndarray # used in normal factored version - v_col: np.ndarray - v: np.ndarray # only used without factoring - m: np.ndarray # only used with momentum - - -class Adafactor(OptimizerDef): - """Adafactor optimizer. - - Adafactor is described in https://arxiv.org/abs/1804.04235. - """ - - def __init__( - self, - learning_rate: Optional[float] = None, - factored: bool = True, - multiply_by_parameter_scale: Union[bool, HParamMap] = True, - beta1: Optional[float] = None, - decay_rate: float = 0.8, - step_offset: int = 0, - clipping_threshold: Optional[float] = 1.0, - weight_decay_rate: Optional[float] = None, - min_dim_size_to_factor: int = 128, - epsilon1: float = 1e-30, - epsilon2: float = 1e-3, - dtype_momentum: Dtype = jnp.float32, - factor_map: Optional[HParamMap] = None, - logical_factor_rules: Optional[Mapping[str, FactorDim]] = None, - weight_decay_rate_lr_exponent: Optional[float] = None, - global_norm_clip_threshold: Optional[float] = None, - max_parameter_scale: Optional[float] = None, - skip_nan_updates: Optional[bool] = False, - ): - """Constructor for the Adafactor optimizer. - - - Args: - learning_rate: float: learning rate. NB: the natural scale for adafactor - LR is markedly different from Adam, one doesn't use the 1/sqrt(hidden) - correction for this optimizer with attention-based models. - factored: boolean: whether to use factored second-moment estimator for 2d - variables. - multiply_by_parameter_scale: boolean: if True, then scale provided - learning_rate by parameter norm. if False, provided learning_rate is - absolute step size. - beta1: an optional float value between 0 and 1, enables momentum and uses - extra memory if non-None! None by default. - decay_rate: float: controls second-moment exponential decay schedule. - step_offset: for finetuning, one may optionally set this to the starting - step-number of the finetuning phase to reset the second moment - accumulators after pretraining. Does not affect the momentum even if it - was used during pretraining. - clipping_threshold: an optional float >= 1, if None no update clipping. - weight_decay_rate: optional rate at which to decay weights. - min_dim_size_to_factor: only factor accumulator if two array dimensions - are at least this size. - epsilon1: Regularization constant for squared gradient. - epsilon2: Regularization constant for parameter scale. - dtype_momentum: dtype of momentum buffers. - factor_map: hparam-map from key path to manual factorization rules. - logical_factor_rules: factorization rules provided as a set of mappings - from logical axis name to ROW, COLUMN, BATCH, or NONE. Supersedes - factor_map if `set_param_axes` is called. - weight_decay_rate_lr_exponent: If present, weight decay rate is computed - as (learning_rate ** weight_decay_rate_lr_exponent). If - weight_decay_rate is also present, then multiply by it. - global_norm_clip_threshold: If set, will clip gradients by global norm - before Adafactor stats are applied. - max_parameter_scale: If set, clips the parameter scale to a maximum value, - which helps prevent parameters from growing without bound. - skip_nan_updates: If set, any parameter that would have been updated by a - NaN value after a applying gradients will be kept with the earlier value - it had. - """ - if not factored and factor_map is not None: - raise ValueError( - 'Adafactor factored is False but factorization rules ' - 'have been provided.' - ) - if not isinstance(multiply_by_parameter_scale, (bool, HParamMap)): - raise TypeError( - '`multiply_by_parameter_scale` must be either bool or `HParamMap` ' - f'type. Got {type(multiply_by_parameter_scale)}' - ) - - if not isinstance(factor_map, (type(None), HParamMap)): - raise TypeError( - '`factor_map` must be either None or `HParamMap` type. Got ' - f'{type(factor_map)}' - ) - - hyper_params = _AdafactorHyperParams( - learning_rate, - factored, - multiply_by_parameter_scale, - beta1, - decay_rate, - step_offset, - clipping_threshold, - weight_decay_rate, - min_dim_size_to_factor, - epsilon1, - epsilon2, - factor_map, - logical_factor_rules, - weight_decay_rate_lr_exponent, - global_norm_clip_threshold, - max_parameter_scale, - skip_nan_updates, - ) - self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum) - super().__init__(hyper_params) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Adafactor): - return False - return ( - self.hyper_params == other.hyper_params - and self.dtype_momentum == other.dtype_momentum - ) - - def __hash__(self) -> int: - return id(self) - - @staticmethod - def _decay_rate_pow(i: int, exponent: float = 0.8) -> float: - """Default Adafactor second-moment decay schedule.""" - t = jnp.array(i, jnp.float32) + 1.0 - return 1.0 - t ** (-exponent) # pytype: disable=bad-return-type # jnp-type - - @staticmethod - def _parse_rule( - rule: Optional[FactorRule], - shape: Sequence[int], - path: str, - fallback_to_heuristics=True, - ) -> Tuple[ - Tuple[int, ...], - Optional[Union[HeuristicRule, Tuple[Tuple[int, ...], Tuple[int, ...]]]], - ]: - """Parses specification and return factored dims and dims for averaging. - - Adafactor needs to know the two largest dimensions to factorize along. - Traditionally it used a heuristic, but we want finer control over these - factorization dimensions. Additionally, there are situations where - parameters are batched together for e.g. scanned layers and QKV fusion, - and we want to ensure that the scale updates and clipping thresholds are - calculated _within_ each array and not across the entire batched array. - - Args: - rule: the rule is either None (default to heuristic behavior) or a tuple - of the same rank as the `param` array containing a FactorDim.ROW or - FactorDim.COLUMN to mark dimensions to factorize in two row and column - sets, and optionally dimensions marked FactorDim.BATCH to denote batched - dimensions that should not be averaged over. e.g. (BATCH, ROW, COLUMN, - COLUMN) - shape: shape of the variable - path: '/' joined parameter path. - fallback_to_heuristics: whether to fallback to heuristic factorization - rule. For most cases this should be set to `True`. - - Returns: - tuple of: tuple of dimensions to average over, 2-tuple of dimensions to - factorize over. - """ - param_ndim = len(shape) - - if rule is None: - # No factorization. - return tuple(np.arange(param_ndim)), None - - if rule is HEURISTIC_RULE: - if param_ndim > 2: - raise ValueError( - 'A parameter with rank strictly higher than 2 must have an ' - f'explicit factorization rule: {path}, {shape}' - ) - # Even if no explicit rule is provided for the param, we still want to - # average over all the dimensions for computing the RMS scale. - return tuple(np.arange(param_ndim)), HEURISTIC_RULE - - if len(rule) != param_ndim: - raise ValueError( - f'Factorization rule {rule} has incorrect rank ' - f'for param of rank {param_ndim}: {path}, {shape}' - ) - - row_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.ROW) - col_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.COLUMN) - batched_dims = tuple( - idx for idx, d in enumerate(rule) if d == FactorDim.BATCH - ) - averaging_dims = tuple(np.delete(np.arange(param_ndim), batched_dims)) - factor_dims = (row_dims, col_dims) - if factor_dims == ((), ()): - factor_dims = None - - if fallback_to_heuristics and param_ndim <= 2 and not batched_dims: - logging.warning( - 'Since rank of parameter %s %d is less than or equal to 2, the ' - 'factorization method falls back to heuristics and the provided ' - 'factor rule %s is ignored.', - path, - param_ndim, - rule, - ) - return tuple(np.arange(param_ndim)), HEURISTIC_RULE - - return averaging_dims, factor_dims - - def _factored_dims( - self, shape: Sequence[int] - ) -> Optional[Tuple[Tuple[int], Tuple[int]]]: - """Whether to use a factored second moment estimator. - - If there are not two dimensions of size >= min_dim_size_to_factor, then we - do not factor. If we do factor the accumulator, then this function returns a - tuple of the two largest axes to reduce over. - - Args: - shape: a Shape - - Returns: - None or a tuple of ints - """ - if not self.hyper_params.factored or len(shape) < 2: - return None - sorted_dims = np.argsort(shape) - if shape[sorted_dims[-2]] < self.hyper_params.min_dim_size_to_factor: - return None - return (int(sorted_dims[-2]),), (int(sorted_dims[-1]),) - - def init_param_state(self, param, path): - shape = param.shape - state = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']} - if self.hyper_params.factored: - factor_rule = ( - self.hyper_params.factor_map[path] - if self.hyper_params.factor_map - else HEURISTIC_RULE - ) - else: - factor_rule = None - _, factored_dims = self._parse_rule(factor_rule, param.shape, path) - if factored_dims is HEURISTIC_RULE: - factored_dims = self._factored_dims(shape) - if factored_dims is not None: - # We have ruled out the types None and HeuristicRule, so there's only one - # remaining type. (This line is a no-op but is helpful for static type - # analyzers.) - factored_dims = typing.cast( - Tuple[Tuple[int, ...], Tuple[int, ...]], factored_dims - ) - d1, d0 = factored_dims - vr_shape = np.delete(shape, d0) - vc_shape = np.delete(shape, d1) - state['v_row'] = jnp.zeros(vr_shape, dtype=jnp.float32) - state['v_col'] = jnp.zeros(vc_shape, dtype=jnp.float32) - else: - state['v'] = jnp.zeros(param.shape, dtype=jnp.float32) - if self.hyper_params.beta1 is not None: - state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum) - return _AdafactorParamState(**state) # pytype: disable=wrong-arg-types # jnp-type - - def init_state(self, params): - params_flat = utils.flatten_dict_string_keys(params) - param_states_flat = [ - self.init_param_state(param, path) - for path, param in params_flat.items() - ] - param_states_flat = { - k: v for k, v in zip(params_flat.keys(), param_states_flat) - } - param_states = _restore(params, param_states_flat) - state = OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states) - return state - - def apply_param_gradient(self, step, hyper_params, param, state, grad, path): - assert hyper_params.learning_rate is not None, 'no learning rate provided.' - learning_rate = hyper_params.learning_rate - beta1 = hyper_params.beta1 - decay_rate = hyper_params.decay_rate - step_offset = hyper_params.step_offset - multiply_by_parameter_scale = hyper_params.multiply_by_parameter_scale - max_parameter_scale = hyper_params.max_parameter_scale - clipping_threshold = hyper_params.clipping_threshold - weight_decay_rate = hyper_params.weight_decay_rate - epsilon1 = hyper_params.epsilon1 - epsilon2 = hyper_params.epsilon2 - if hyper_params.weight_decay_rate_lr_exponent: - weight_decay_rate = ( - weight_decay_rate or 1.0 - ) * learning_rate**hyper_params.weight_decay_rate_lr_exponent - - if self.hyper_params.factored: - factor_rule = ( - self.hyper_params.factor_map[path] - if self.hyper_params.factor_map - else HEURISTIC_RULE - ) - else: - factor_rule = None - averaging_dims, factored_dims = self._parse_rule( - factor_rule, param.shape, path - ) - - grad = grad.astype(jnp.float32) - - updates = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']} - decay_rate = self._decay_rate_pow(step - step_offset, exponent=decay_rate) - update_scale = learning_rate - - if isinstance(multiply_by_parameter_scale, HParamMap): - multiply_by_parameter_scale = multiply_by_parameter_scale[path] - if multiply_by_parameter_scale: - param_scale = jnp.sqrt( - jnp.mean(param * param, axis=averaging_dims, keepdims=True) - ) - # Clip param_scale to a minimum value of epsilon2. - param_scale = jnp.maximum(param_scale, epsilon2) - # Clip param_scale to a maximum value, if specified. - if max_parameter_scale is not None: - param_scale = jnp.minimum(param_scale, max_parameter_scale) - update_scale *= param_scale - mixing_rate = 1.0 - decay_rate - - grad_sqr = grad * grad + epsilon1 - if factored_dims is HEURISTIC_RULE: - factored_dims = self._factored_dims(param.shape) - if factored_dims is not None: - # We have ruled out the types None and HeuristicRule, so there's only one - # remaining type. (This line is a no-op but is helpful for static type - # analyzers.) - factored_dims = typing.cast( - Tuple[Tuple[int, ...], Tuple[int, ...]], factored_dims - ) - d1, d0 = factored_dims - new_v_row = decay_rate * state.v_row + mixing_rate * jnp.mean( - grad_sqr, axis=d0 - ) - new_v_col = decay_rate * state.v_col + mixing_rate * jnp.mean( - grad_sqr, axis=d1 - ) - updates['v_row'] = new_v_row - updates['v_col'] = new_v_col - reduced_d1 = tuple(d - len([e for e in d0 if e < d]) for d in d1) - - row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) - row_factor = (new_v_row / row_col_mean) ** -0.5 - col_factor = (new_v_col) ** -0.5 - y = ( - grad - * jnp.expand_dims(row_factor, axis=d0) - * jnp.expand_dims(col_factor, axis=d1) - ) - else: - new_v = decay_rate * state.v + mixing_rate * grad_sqr - updates['v'] = new_v - y = grad * (new_v) ** -0.5 - - if clipping_threshold is not None: - clipping_denom = jnp.maximum( - 1.0, - jnp.sqrt(jnp.mean(y * y, axis=averaging_dims, keepdims=True)) - / clipping_threshold, - ) - y /= clipping_denom - - subtrahend = update_scale * y - if beta1 is not None: - new_m = beta1 * state.m + (1.0 - beta1) * subtrahend - subtrahend = new_m - updates['m'] = new_m.astype(self.dtype_momentum) - - if weight_decay_rate is not None: - new_param = (1.0 - weight_decay_rate) * param - subtrahend - else: - new_param = param - subtrahend - - if hyper_params.skip_nan_updates: - updates['v_row'] = jnp.where( - jnp.isnan(updates['v_row']), state.v_row, updates['v_row'] - ) - updates['v_col'] = jnp.where( - jnp.isnan(updates['v_col']), state.v_col, updates['v_col'] - ) - updates['v'] = jnp.where(jnp.isnan(updates['v']), state.v, updates['v']) - updates['m'] = jnp.where(jnp.isnan(updates['m']), state.m, updates['m']) - new_param = jnp.where(jnp.isnan(new_param), param, new_param) - new_state = _AdafactorParamState(**updates) - - return new_param.astype(param.dtype), new_state - - def apply_gradient(self, hyper_params, params, state, grads): - """Applies a gradient for a set of parameters. - - Args: - hyper_params: a named tuple of hyper parameters. - params: the parameters that should be updated. - state: a named tuple containing the state of the optimizer - grads: the gradient tensors for the parameters. - - Returns: - A tuple containing the new parameters and the new optimizer state. - """ - step = state.step - # We assume that params, param_states, and grads are all dict-like here. - params_flat_dict = utils.flatten_dict_string_keys(params) - params_paths = params_flat_dict.keys() - params_flat = params_flat_dict.values() - # extra paranoia to guarantee identical value ordering - states_flat = utils.flatten_dict_string_keys(state.param_states) - states_flat = [states_flat[k] for k in params_paths] - grads_flat = utils.flatten_dict_string_keys(grads) - grads_flat = [grads_flat[k] for k in params_paths] - - if hyper_params.global_norm_clip_threshold: - # Paper: http://proceedings.mlr.press/v28/pascanu13.pdf - # TF: https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm - squared_l2_norms = [jnp.sum(jnp.square(g)) for g in grads_flat] - global_norm = jnp.sqrt(jnp.sum(jnp.array(squared_l2_norms))) - scale = hyper_params.global_norm_clip_threshold * jnp.minimum( - 1.0 / hyper_params.global_norm_clip_threshold, 1.0 / global_norm - ) - grads_flat = [g * scale for g in grads_flat] - - out = [ - self.apply_param_gradient(step, hyper_params, param, state, grad, path) - for param, state, grad, path in zip( - params_flat, states_flat, grads_flat, params_paths - ) - ] - - new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ()) - new_params_flat = {k: v for k, v in zip(params_paths, new_params_flat)} - new_states_flat = {k: v for k, v in zip(params_paths, new_states_flat)} - new_params = _restore(params, new_params_flat) - new_param_states = _restore(params, new_states_flat) - new_state = OptimizerState(step + 1, new_param_states) - - return new_params, new_state - - def set_param_axes(self, param_logical_axes): - """Sets Adafactor factorization map from logical axis names tree.""" - logical_factor_rules = self.hyper_params.logical_factor_rules - if logical_factor_rules is None: - return - - # pylint:disable=invalid-name - NONE = FactorDim.NONE - COLUMN = FactorDim.COLUMN - ROW = FactorDim.ROW - - # pylint:enable=invalid-name - - def apply_rules(axes): - # Partially factorized params are marked as unfactorized, preserving - # only BATCH axis annotations. We also check for incompletely factorized - # params that have ROW, COLUMN but also accidental NONE dimensions and - # raise an error in that case. - axis_rules = tuple(logical_factor_rules[x] for x in axes) - axis_rules = tuple(factor_name_to_factordim(x) for x in axis_rules) - if ROW in axis_rules and COLUMN in axis_rules and NONE in axis_rules: - raise ValueError(f'Incomplete adafactor spec {axis_rules} for {axes}!') - if ROW not in axis_rules or COLUMN not in axis_rules: - axis_rules = tuple( - NONE if x in (ROW, COLUMN) else x for x in axis_rules - ) - return axis_rules - - factor_map = jax.tree_util.tree_map(apply_rules, param_logical_axes) - factor_map = utils.flatten_dict_string_keys(factor_map) - - self.hyper_params = self.hyper_params.replace(factor_map=factor_map) - - def derive_logical_axes(self, optimizer_state, param_logical_axes): - """Derives optimizer logical partitioning from model logical partitions.""" - optimizer_logical_axes = jax.tree_util.tree_map( - lambda x: None, optimizer_state.state_dict() - ) - optimizer_logical_axes['target'] = param_logical_axes - - def factor_rule(logical_axes, adafactor_leaf): - return dict( - v_row=None, - v_col=None, - v=logical_axes if adafactor_leaf['v'].shape != (1,) else None, - m=logical_axes if self.hyper_params.beta1 else None, - ) - - optimizer_logical_axes['state']['param_states'] = jax.tree_util.tree_map( - factor_rule, - unfreeze(param_logical_axes), - optimizer_state.state_dict()['state']['param_states'], - ) - - return optimizer_state.restore_state(unfreeze(optimizer_logical_axes)) diff --git a/t5x-main/t5x/adafactor_test.py b/t5x-main/t5x/adafactor_test.py deleted file mode 100644 index 59cada41186f9b00ad0567786b249954755a6a49..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/adafactor_test.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.adafactor.""" - -import functools -import operator -from typing import Sequence - -from absl.testing import absltest -from absl.testing import parameterized -import flax -from flax import traverse_util -import jax -from jax import numpy as jnp -from jax import random -import numpy as np -from t5x import adafactor -from t5x import optimizers - -OptimizerState = optimizers.OptimizerState - -_AdafactorHyperParams = adafactor._AdafactorHyperParams -_AdafactorParamState = adafactor._AdafactorParamState - -_BATCH = adafactor.FactorDim.BATCH -_ROW = adafactor.FactorDim.ROW -_COL = adafactor.FactorDim.COLUMN - -# Testing helpers - - -def _assert_numpy_allclose(a, b, atol=None, rtol=None): - a, b = jnp.array(a), jnp.array(b) - a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a - b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b - kw = {} - if atol: - kw['atol'] = atol - if rtol: - kw['rtol'] = rtol - np.testing.assert_allclose(a, b, **kw) - - -def check_eq(xs, ys, atol=None, rtol=None): - xs_leaves, xs_tree = jax.tree_util.tree_flatten(xs) - ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys) - assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}" - assert jax.tree_util.tree_all( - jax.tree.map( - lambda x, y: np.array(x).shape == np.array(y).shape, - xs_leaves, - ys_leaves, - ) - ), "Leaves' shapes don't match." - assert jax.tree.map( - functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol), - xs_leaves, - ys_leaves, - ) - - -def flattened_state_dict(x): - s = flax.serialization.to_state_dict(x) - return flax.traverse_util.flatten_dict(s, sep='/') - - -def tree_shape(x): - return jax.tree.map(jnp.shape, x) - - -def tree_equals(x, y): - return jax.tree_util.tree_all(jax.tree.map(operator.eq, x, y)) - - -def _get_multi_adafactor( - learning_rate: float, - step_offset: int, - adafactor_exclude_from_parameter_scale: Sequence[str], -) -> optimizers.MultiOptimizer: - """Get adafactor with support for excluding some parameters from scaling.""" - - def _should_not_scale(path): - return any([s in path for s in adafactor_exclude_from_parameter_scale]) - - scaled_vars = traverse_util.ModelParamTraversal( - lambda path, _: not _should_not_scale(path) - ) - unscaled_vars = traverse_util.ModelParamTraversal( - lambda path, _: _should_not_scale(path) - ) - scaled_opt = adafactor.Adafactor( - learning_rate, decay_rate=0.8, step_offset=step_offset - ) - unscaled_opt = adafactor.Adafactor( - learning_rate, - decay_rate=0.8, - step_offset=step_offset, - multiply_by_parameter_scale=False, - ) - return optimizers.MultiOptimizer( - ((scaled_vars, scaled_opt), (unscaled_vars, unscaled_opt)) - ) - - -# Inline test data - -MODEL_SHAPE = { - 'decoder': { - 'decoder_norm': {'scale': [128]}, - 'layers_0': { - 'encoder_decoder_attention': { - 'key': {'kernel': [128, 256]}, - 'out': {'kernel': [256, 128]}, - 'query': {'kernel': [128, 256]}, - 'value': {'kernel': [128, 256]}}, - 'mlp': { - 'wi': {'kernel': [128, 512]}, - 'wo': {'kernel': [512, 128]}}, - 'pre_cross_attention_layer_norm': {'scale': [128]}, - 'pre_mlp_layer_norm': {'scale': [128]}, - 'pre_self_attention_layer_norm': {'scale': [128]}, - 'self_attention': { - 'key': {'kernel': [128, 256]}, - 'out': {'kernel': [256, 128]}, - 'query': {'kernel': [128, 256]}, - 'value': {'kernel': [128, 256]}}}, - 'layers_1': { - 'encoder_decoder_attention': { - 'key': {'kernel': [128, 128]}, - 'out': {'kernel': [128, 128]}, - 'query': {'kernel': [128, 128]}, - 'value': {'kernel': [128, 128]}}, - 'mlp': { - 'wi': {'kernel': [128, 512]}, - 'wo': {'kernel': [512, 128]}}, - 'pre_cross_attention_layer_norm': {'scale': [128]}, - 'pre_mlp_layer_norm': {'scale': [128]}, - 'pre_self_attention_layer_norm': {'scale': [128]}, - 'self_attention': { - 'key': {'kernel': [128, 256]}, - 'out': {'kernel': [256, 128]}, - 'query': {'kernel': [128, 256]}, - 'value': {'kernel': [128, 256]}}}, - 'relpos_bias': {'rel_embedding': [2, 32]}}, - 'encoder': { - 'encoder_norm': {'scale': [128]}, - 'layers_0': { - 'attention': { - 'key': {'kernel': [128, 256]}, - 'out': {'kernel': [256, 128]}, - 'query': {'kernel': [128, 256]}, - 'value': {'kernel': [128, 256]}}, - 'mlp': { - 'wi': {'kernel': [128, 512]}, - 'wo': {'kernel': [512, 128]}}, - 'pre_attention_layer_norm': {'scale': [128]}, - 'pre_mlp_layer_norm': {'scale': [128]}}, - 'layers_1': { - 'attention': { - 'key': {'kernel': [128, 256]}, - 'out': {'kernel': [256, 128]}, - 'query': {'kernel': [128, 256]}, - 'value': {'kernel': [128, 256]}}, - 'mlp': { - 'wi': {'kernel': [128, 512]}, - 'wo': {'kernel': [512, 128]}}, - 'pre_attention_layer_norm': {'scale': [128]}, - 'pre_mlp_layer_norm': {'scale': [128]}}, - 'relpos_bias': {'rel_embedding': [2, 32]}}, - 'token_embedder': {'embedding': [32128, 128]}} # pyformat: disable - - -class AdafactorTest(parameterized.TestCase): - - # Classic Adafactor Behavior Tests - - def test_2D_simple(self): - x = {'a': jnp.ones((24, 16))} - opt_def = adafactor.Adafactor(min_dim_size_to_factor=8) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)} - self.assertTrue(tree_equals(shapes, ref)) - - def test_2D_simple_nofactor(self): - x = {'a': jnp.ones((24, 16))} - opt_def = adafactor.Adafactor(min_dim_size_to_factor=32) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (1,), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)} - self.assertTrue(tree_equals(shapes, ref)) - - def test_2D_simple_nofactor_momentum(self): - x = {'a': jnp.ones((24, 16))} - opt_def = adafactor.Adafactor(min_dim_size_to_factor=32, beta1=0.1) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (24, 16), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)} - self.assertTrue(tree_equals(shapes, ref)) - - def test_3D_simple(self): - x = {'a': jnp.ones((24, 4, 16))} - factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),)) - opt_def = adafactor.Adafactor( - min_dim_size_to_factor=8, factor_map=factor_map - ) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)} - self.assertTrue(tree_equals(shapes, ref)) - - def test_init_state(self): - params = {'x': np.zeros((3, 2))} - optimizer_def = adafactor.Adafactor( - learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=0 - ) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdafactorHyperParams( - 0.1, True, True, None, 0.8, 0, 1.0, None, 0, 1e-30, 1e-3 - ) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = OptimizerState( - 0, - { - 'x': _AdafactorParamState( - np.zeros((2,)), np.zeros((3,)), np.zeros((1,)), np.zeros((1,)) - ) - }, - ) - check_eq(state, expected_state) - - # unfactorized - optimizer_def = adafactor.Adafactor( - learning_rate=0.1, decay_rate=0.8, beta1=0.0, min_dim_size_to_factor=32 - ) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdafactorHyperParams( - 0.1, True, True, 0.0, 0.8, 0, 1.0, None, 32, 1e-30, 1e-3 - ) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = OptimizerState( - 0, - { - 'x': _AdafactorParamState( - np.zeros((1,)), - np.zeros((1,)), - np.zeros((3, 2)), - np.zeros((3, 2)), - ) - }, - ) - check_eq(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = adafactor.Adafactor( - learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0 - ) - params = {'x': np.ones((3, 2), np.float32)} - state = OptimizerState( - 1, - { - 'x': _AdafactorParamState( - np.array([0.9, 0.9]), - np.array([0.1, 0.1, 0.1]), - np.zeros((1,)), - np.zeros((1,)), - ) - }, - ) - grads = {'x': np.ones((3, 2), np.float32)} - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads - ) - expected_new_state = OptimizerState( - 2, - { - 'x': _AdafactorParamState( - np.array([0.9574349, 0.9574349]), - np.array([0.6169143, 0.6169143, 0.6169143]), - np.zeros((1,)), - np.zeros((1,)), - ) - }, - ) - expected_new_params = {'x': 0.9 * np.ones((3, 2))} - check_eq(new_params, expected_new_params) - check_eq(new_state, expected_new_state, rtol=1e-6) - - # unfactored w momentum - optimizer_def = adafactor.Adafactor( - learning_rate=0.1, beta1=0.0, decay_rate=0.8, min_dim_size_to_factor=32 - ) - params = {'x': np.ones((3, 2), np.float32)} - state = OptimizerState( - 1, - { - 'x': _AdafactorParamState( - np.zeros( - 1, - ), - np.zeros( - 1, - ), - 0.5 * np.ones((3, 2)), - np.zeros((3, 2)), - ) - }, - ) - grads = {'x': np.ones((3, 2), np.float32)} - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads - ) - expected_new_params = {'x': 0.9 * np.ones((3, 2))} - check_eq(new_params, expected_new_params) - expected_new_state = OptimizerState( - 2, - { - 'x': _AdafactorParamState( - np.array([0.0]), - np.array([0.0]), - 0.787174 * np.ones((3, 2)), - 0.1 * np.ones((3, 2)), - ) - }, - ) - check_eq(new_state, expected_new_state, rtol=1e-6) - - def test_apply_gradient_with_global_norm_clipping(self): - optimizer_def = adafactor.Adafactor( - learning_rate=0.1, - decay_rate=0.8, - min_dim_size_to_factor=0, - global_norm_clip_threshold=1.0, - ) - params = {'x': np.ones((3, 2), np.float32)} - state = OptimizerState( - 1, - { - 'x': _AdafactorParamState( - np.array([0.9, 0.9]), - np.array([0.1, 0.1, 0.1]), - np.zeros((1,)), - np.zeros((1,)), - ) - }, - ) - grads = {'x': np.ones((3, 2), np.float32)} - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads - ) - expected_new_state = OptimizerState( - 2, - { - 'x': _AdafactorParamState( - np.array([0.478811, 0.478811]), - np.array([0.13829, 0.13829, 0.13829]), - np.zeros((1,)), - np.zeros((1,)), - ) - }, - ) - expected_new_params = {'x': 0.9 * np.ones((3, 2))} - check_eq(new_params, expected_new_params) - check_eq(new_state, expected_new_state, rtol=1e-6) - - def test_factorizes(self): - params = {'x': np.zeros((64, 64))} - optimizer_def = adafactor.Adafactor( - learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=32 - ) - state = optimizer_def.init_state(params) - self.assertEqual(state.param_states['x'].v.shape, (1,)) - self.assertEqual(state.param_states['x'].m.shape, (1,)) - self.assertEqual(state.param_states['x'].v_row.shape, (64,)) - self.assertEqual(state.param_states['x'].v_col.shape, (64,)) - - params = {'x': np.zeros((31, 64))} - optimizer_def = adafactor.Adafactor( - learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=32 - ) - state = optimizer_def.init_state(params) - self.assertEqual(state.param_states['x'].v.shape, (31, 64)) - self.assertEqual(state.param_states['x'].m.shape, (1,)) - self.assertEqual(state.param_states['x'].v_row.shape, (1,)) - self.assertEqual(state.param_states['x'].v_col.shape, (1,)) - - # Manually specified factorization rules tests. - - @parameterized.parameters( - {'rule': (_ROW, _COL)}, - {'rule': (_COL, _ROW)}, - ) - def test_2D_ignore_specified_factor_rule(self, rule): - x = {'a': jnp.ones((24, 16))} - factor_map = adafactor.HParamMap((('a', rule),)) - opt_def = adafactor.Adafactor( - min_dim_size_to_factor=8, factor_map=factor_map - ) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - # Since param is 2D, the explicit factor rule should be ignored and falls - # back to heuristics where v_row corresponds to the smaller dim. - ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)} - self.assertTrue(tree_equals(shapes, ref)) - - def test_3D_simple_manual_rules(self): - x = {'a': jnp.ones((24, 4, 16))} - - factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),)) - opt_def = adafactor.Adafactor( - min_dim_size_to_factor=8, factor_map=factor_map - ) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)} - self.assertTrue(tree_equals(shapes, ref)) - - factor_map = adafactor.HParamMap((('a', (_ROW, _BATCH, _COL)),)) - opt_def = adafactor.Adafactor( - min_dim_size_to_factor=8, factor_map=factor_map - ) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (4, 16), 'a/v_row': (24, 4)} - self.assertTrue(tree_equals(shapes, ref)) - - factor_map = adafactor.HParamMap((('a', (_COL, _ROW, _ROW)),)) - opt_def = adafactor.Adafactor( - min_dim_size_to_factor=8, factor_map=factor_map - ) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (4, 16)} - self.assertTrue(tree_equals(shapes, ref)) - - factor_map = adafactor.HParamMap((('a', (_COL, _COL, _ROW)),)) - opt_def = adafactor.Adafactor( - min_dim_size_to_factor=8, factor_map=factor_map - ) - optimizer = opt_def.create(x) - shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) - ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (16,)} - self.assertTrue(tree_equals(shapes, ref)) - - def test_standard_factor_rules(self): - # one-off test to double-check that we're following the previous - # heuristic convention for rows/columns. - def test_standard_factor_rules(): - token_embedding = (_COL, _ROW) - attn_qkv = (_ROW, _COL) - attn_out = (_COL, _ROW) - mlp_in = (_ROW, _COL) - mlp_out = (_COL, _ROW) - return ( - (r'_layer_norm/(bias|scale)', None), - (r'(encoder|decoder)_norm/(bias|scale)', None), - ( - r'(encoder_decoder_|self_|\b)attention/(query|key|value)/kernel', - attn_qkv, - ), - (r'(encoder_decoder_|self_|\b)attention/out/kernel', attn_out), - (r'mlp/DenseGeneral_\d+/bias', None), - (r'mlp/wi(_\d+)?/kernel', mlp_in), - (r'mlp/wo/kernel', mlp_out), - (r'\brelpos_bias', None), - (r'token_embedder', token_embedding), - (r'.*', adafactor.HEURISTIC_RULE), - ) - - # create fake model parameters - k = jax.random.PRNGKey(0) - params = jax.tree.map( - lambda shape: jax.random.uniform(k, shape), - MODEL_SHAPE, - is_leaf=lambda x: isinstance(x, list), - ) - # make traditional adafactor state with heuristic - factor_map1 = adafactor.HParamMap(((r'.*', adafactor.HEURISTIC_RULE),)) - optimizer_def1 = adafactor.Adafactor( - 0.1, - decay_rate=0.8, - step_offset=0, - multiply_by_parameter_scale=True, - factor_map=factor_map1, - ) - optimizer1 = optimizer_def1.create(params) - # make traditional adafactor state with explicit rules - factor_map2 = adafactor.HParamMap(test_standard_factor_rules()) - optimizer_def2 = adafactor.Adafactor( - 0.1, - decay_rate=0.8, - step_offset=0, - multiply_by_parameter_scale=True, - factor_map=factor_map2, - ) - optimizer2 = optimizer_def2.create(params) - # are they the same? - check_eq(optimizer1.state.param_states, optimizer2.state.param_states) - - @parameterized.parameters( - {'shape': (64, 64)}, - {'shape': (64, 132)}, - {'shape': (132, 64)}, - {'shape': (132, 132)}, - {'shape': (132, 140)}, - {'shape': (140, 132)}, - ) - def test_no_factor_map_equivalence(self, shape): - k = random.PRNGKey(0) - k1, k2 = random.split(k) - p = {'a': random.uniform(k1, shape)} - g = {'a': random.uniform(k2, shape)} - - orig_opt = adafactor.Adafactor(0.1).create(p) - new_opt = adafactor.Adafactor(0.1, factor_map=None).create(p) - check_eq(orig_opt.state_dict(), new_opt.state_dict()) - - orig_opt1 = orig_opt.apply_gradient(g) - new_opt1 = new_opt.apply_gradient(g) - check_eq(orig_opt1.state_dict(), new_opt1.state_dict()) - - @parameterized.parameters( - {'shape': (128, 128), 'rule': (_ROW, _COL)}, - {'shape': (132, 128), 'rule': (_COL, _ROW)}, - {'shape': (128, 132), 'rule': (_ROW, _COL)}, - ) - def test_simple_equivalence(self, shape, rule): - k = random.PRNGKey(0) - k1, k2 = random.split(k) - k3, k4 = random.split(k1) - k5, k6 = random.split(k2) - - p = {'a': random.uniform(k3, shape), 'b': random.uniform(k4, shape)} - g = {'a': random.uniform(k5, shape), 'b': random.uniform(k6, shape)} - - orig_opt = adafactor.Adafactor(0.1).create(p) - factor_map = adafactor.HParamMap( - rules=(('a', rule), ('.*', adafactor.HEURISTIC_RULE)) - ) - new_opt = adafactor.Adafactor(0.1, factor_map=factor_map).create(p) - check_eq(orig_opt.state_dict(), new_opt.state_dict()) - - orig_opt1 = orig_opt.apply_gradient(g) - new_opt1 = new_opt.apply_gradient(g) - check_eq(orig_opt1.state_dict(), new_opt1.state_dict()) - - @parameterized.parameters({'shape': (64, 64)}, {'shape': (132, 132)}) - def test_multiply_by_parameter_scale_equivalence(self, shape): - # Use large parameter values to magnify the parameter scaling effect. - p = {'a': np.random.randn(*shape) * 100, 'b': np.random.randn(*shape) * 100} - g = {'a': np.random.randn(*shape), 'b': np.random.randn(*shape)} - orig_opt = _get_multi_adafactor( - 3.0, 0, adafactor_exclude_from_parameter_scale=('a',) - ).create(p) - scaling_map = adafactor.HParamMap([('a', False), ('.*', True)]) - new_opt = adafactor.Adafactor( - 3.0, multiply_by_parameter_scale=scaling_map - ).create(p) - check_eq(orig_opt.state_dict(), new_opt.state_dict()) - - orig_opt1 = orig_opt.apply_gradient(g) - new_opt1 = new_opt.apply_gradient(g) - check_eq(orig_opt1.state_dict(), new_opt1.state_dict()) - - def test_3d_without_factor_map(self): - x = {'a': jnp.ones((24, 4, 16))} - opt_def = adafactor.Adafactor(factor_map=None) - with self.assertRaises(ValueError): - _ = opt_def.create(x) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/assert_gc_disabled_during_import_test_util.py b/t5x-main/t5x/assert_gc_disabled_during_import_test_util.py deleted file mode 100644 index a395327ef46c32ecde49053babf77a99f39d2d03..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/assert_gc_disabled_during_import_test_util.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test utility for disable_gc_during_import_test.py.""" - -import gc - -if gc.isenabled(): - raise ValueError("Expected gc to be disabled; was enabled.") diff --git a/t5x-main/t5x/binary_search.py b/t5x-main/t5x/binary_search.py deleted file mode 100644 index 36e264c31927510ffa967ac43c402ff12fcfb60f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/binary_search.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Binary search over float32 bits. - -Includes fast algorithms top-k masking and top-p masking on probability -distributions. -""" - -from typing import Callable, Sequence - -import jax -from jax import lax -from jax import numpy as jnp - - -def int32_bsearch( - batch_shape: Sequence[int], predicate: Callable[[jnp.ndarray], jnp.ndarray] -): - """Batched binary search over int32 values. - - For each element of the batch, search for the largest int32 (closest to - positive infinity) for which the predicate is False. If the predicate is - always True, returns the minimum int32 value. - - Args: - batch_shape: Shape of the search that we're batching over. - predicate: the query we're searching for. For every batch element, this is - required to be a monotonic function from int32 to bool. In other words, - the predicate must return False for all numbers <= some threshold and then - return True for all numbers > that threshold. The threshold may be - different for different elements of the batch. - - Returns: - For each element of the batch, the largest int32 for which the predicate - returns False. Shape: batch_shape. - """ - current_bits = jnp.zeros(batch_shape, dtype=jnp.int32) - - # bit 31 is special, because it compares in the opposite order of all other - # bits. we use uint32 due to numpy promotion/casting rules. - midpoint = current_bits - predicate_satisfied = predicate(midpoint) - current_bits = current_bits | jnp.where( - predicate_satisfied, jnp.uint32(1 << 31), jnp.uint32(0) - ) - del midpoint, predicate_satisfied - - def loop_body(i, current_bits): - bit_index = 30 - i - bit = jnp.int32(1 << bit_index) - midpoint = current_bits | bit - predicate_satisfied = predicate(midpoint) - current_bits = current_bits | jnp.where( - predicate_satisfied, jnp.int32(0), bit - ) - return current_bits - - current_bits = lax.fori_loop(0, 31, loop_body, current_bits) - return current_bits - - -def _monotonic_int32_to_float32_bit_pattern(x: int) -> int: - """Converts an int32 to a float32 bit pattern with consistent ordering. - - This function is the unique function that is monotonic with respect to the - floating point total order, see - https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate. Note that - this function returns an int32, not a float32. For the function that returns - float32, see `monotonic_int32_to_float32`. - - Args: - x: int bit pattern. - - Returns: - Bit pattern of a float32 number. - """ - non_sign_bits = jnp.int32((1 << 31) - 1) - # See - # https://stackoverflow.com/questions/20097380/iee-754-total-order-in-standard-c11 - # for the relationship between int32 order and f32 total order, including - # the "xor trick". - - # Flip the sort order for numbers where the sign bit is set. On int32, - # the bit pattern with sign bit set and all other bits clear is the most - # negative bit pattern (it's int32::MIN), whereas on float32 it's the least - # negative bit pattern (it's -0.0). Flipping all the non-sign bits makes the - # int32 sort order consistent with the float32 sort order. - x = x ^ jnp.where(x < 0, non_sign_bits, jnp.int32(0)) - return x - - -def _monotonic_int32_to_float32(x: int) -> jax.Array: - """Converts an int32 to a float32 with consistent ordering. - - This function is the unique function that is monotonic with respect to the - floating point total order, see - https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate. - - Args: - x: int bit pattern. - - Returns: - float32 number with consistent ordering. - """ - x = _monotonic_int32_to_float32_bit_pattern(x) - return lax.bitcast_convert_type(x, jnp.float32) - - -def float32_bsearch(batch_shape, predicate): - """Binary search on finite float32 numbers. - - For each element of the batch, this function searches for the largest finite - non-NaN float32 for which the predicate is False. - - Args: - batch_shape: Shape of the search that we're batching over. - predicate: the query we're searching for. This is required to be monotonic - with respect to the floating point order, i.e. it must be False for all - numbers <= a threshold, and then True for all numbers > the threshold. The - threshold may be different for different elements of the batch. - - Returns: - For each element of the batch, the largest float32 for which the predicate - returns False. Shape: f32[batch_shape]. - """ - exponent_bits = jnp.int32((1 << 31) - (1 << (31 - 8))) - - def int32_predicate(x): - x = _monotonic_int32_to_float32_bit_pattern(x) - is_finite = (x & exponent_bits) != exponent_bits - - # Non-finite numbers (infinity and NaN) are at the very extremes of the - # int32 range, i.e. they include int32::MAX and int32::MIN, plus the numbers - # adjacent to them. For the nonfinite numbers touching int32::MIN, we - # arrange for them to return False from the predicate, and for the nonfinite - # numbers touching int32::MAX, we arrange for them to return True from the - # predicate. x>=0 is an easy way to achieve that. - predicate_on_nonfinite = x >= 0 - x_float32 = lax.bitcast_convert_type(x, jnp.float32) - return jnp.where(is_finite, predicate(x_float32), predicate_on_nonfinite) - - # We search over bit patterns, which requires bit shifting and ordering of bit - # patterns. This is natively supported on int32 but not on float32. - # Additionally, it's more common to reason about int32 bit arithmetic and - # ordering than float32 bit arithmetic and ordering, so we do the core of our - # search in int32. Additionally, this allows us to test the underlying binary - # search on int32 values. - # - # The function _monotonic_int32_to_float32 encapsulates all of the knowledge - # we need about float32 bit patterns. - result = int32_bsearch(batch_shape, int32_predicate) - return _monotonic_int32_to_float32(result) - - -def topk_mask(x: jnp.ndarray, k: int, replace_val: jnp.ndarray) -> jnp.ndarray: - """Sets everything to replace_val, except the top k values per batch element. - - Sharding considerations: this function does 32 reductions over the vocab_size - axis of the input array. To avoid excessive latency from these reductions, you - should ensure that the vocab_size axis is unsharded on input to this function. - Prefer to shard the batch axes instead. - - Scratchpad memory considerations: this function is most efficient if the - entire input array can fit in a fast memory tier. To help ensure this, you may - wish to split the batch axes into microbatches and the microbatches in a - sequential loop. - - Args: - x: Values before masking. [batch..., vocab_size] - k: Number of masked values to return. In presence of ties, more than k - values might be returned. - replace_val: For the masked values of x, what to overwrite them with. - - Returns: - masked version of x. [batch..., vocab_size] - """ - batch_shape = tuple(list(x.shape)[:-1]) # [batch...] - - x_for_loop = x - reduce_axis = x.ndim - 1 - if x.ndim > 1: - # We're going to be doing 32 reductions over 'reduce_axis'. Generally, - # reductions over the last dimension are the most expensive, because they - # involve reducing across vector lanes, which is often not efficient. So - # we transpose the reduce_axis to be the second-last dimension, to avoid - # this inefficiency. - # - # Normaly the XLA compiler would automatically perform this optimization, - # but it doesn't yet see through loops to do so. So we do it ourselves. - x_for_loop = jnp.swapaxes(x_for_loop, -1, -2) - reduce_axis = x.ndim - 2 - - # x: [batch..., vocab_size, batch] - def predicate(threshold): - # threshold: [batch...] - - # Since we've negated, we now want a predicate that is True for small - # numbers and False for large numbers. The result of the bsearch is the - # smallest float32 for which the predicate is False. - threshold = -threshold - - threshold = lax.expand_dims(threshold, (reduce_axis,)) - # threshold: [batch..., 1, last_batch] - - # count_ge: [batch...] - count_gt = jnp.sum(x_for_loop > threshold, axis=reduce_axis) - - return count_gt >= k - - # cutoff: [batch...] - cutoff = float32_bsearch(batch_shape, predicate) - cutoff = -cutoff - # cutoff: [batch..., 1] - cutoff = lax.expand_dims(cutoff, (cutoff.ndim,)) - return jnp.where(x >= cutoff, x, jnp.full_like(x, replace_val)) - - -def topp_mask( - logits: jnp.ndarray, p: float, replace_val: jnp.ndarray -) -> jnp.ndarray: - """Applies top-p masking to logits. - - Masks logits down to the smallest set of choices, such that the total - probability mass is >= p. Values in this set are left as they are. All other - values are set with `replace_val`. - - Sharding considerations: this function does 33 reductions over the vocab_size - axis of the input array. To avoid excessive latency from these reductions, you - should ensure that the vocab_size axis is unsharded on input to this function. - Prefer to shard the batch axes instead. - - Scratchpad memory considerations: this function is most efficient if the - entire input array can fit in a fast memory tier. To help ensure this, you may - wish to split the batch axes into microbatches and the microbatches in a - sequential loop. - - Args: - logits: Logits before masking. [batch..., vocab_size] - p: Minimum probability mass requested. - replace_val: For the masked values of logits, what to overwrite them with. - - Returns: - masked version of x. [batch..., vocab_size] - """ - batch_shape = tuple(list(logits.shape)[:-1]) # [batch...] - - probs = jax.nn.softmax(logits, axis=-1) - - probs_for_reduction = probs - reduce_axis = probs_for_reduction.ndim - 1 - if probs_for_reduction.ndim > 1: - # We're going to be doing 33 reductions over 'reduce_axis'. Generally, - # reductions over the last dimension are the most expensive, because they - # involve reducing across vector lanes, which is often not efficient. So - # we transpose the reduce_axis to be the second-last dimension, to avoid - # this inefficiency. - probs_for_reduction = jnp.swapaxes(probs_for_reduction, -1, -2) - reduce_axis = probs_for_reduction.ndim - 2 - - # As we increase the threshold, the probability mass decreases, and the number - # selected decreases. - # - # We want the largest threshold with the probability mass >= p. Binary search - # searches for when the predicate is False, so we negate the output of the - # predicate, i.e. probability mass < p. - - # probs_for_reduction: [batch..., vocab_size, batch] - def predicate(threshold): - # threshold: [batch...] - threshold = lax.expand_dims(threshold, (reduce_axis,)) - # threshold: [batch..., 1, last_batch] - - # count_ge: [batch...] - probability_mass = jnp.sum( - jnp.where(probs_for_reduction >= threshold, probs_for_reduction, 0.0), - axis=reduce_axis, - ) - - return probability_mass < p - - # threshold: [batch...] - threshold = float32_bsearch(batch_shape, predicate) - # threshold: [batch..., 1] - threshold = lax.expand_dims(threshold, (threshold.ndim,)) - return jnp.where( - probs >= threshold, logits, jnp.full_like(logits, replace_val) - ) diff --git a/t5x-main/t5x/binary_search_test.py b/t5x-main/t5x/binary_search_test.py deleted file mode 100644 index 350a837871179970a2cadb1f91e7ac1969b37809..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/binary_search_test.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for binary_search.""" - -from absl.testing import absltest -import jax -import jax.numpy as jnp -import numpy as np -from t5x import binary_search - -_INT32_MIN = np.iinfo(np.int32).min -_INT32_MAX = np.iinfo(np.int32).max - - -class BinarySearchTest(absltest.TestCase): - - def test_int32_bsearch(self): - a = jnp.asarray( - [ - 1, - 43, - 79, - 2048, - 0, - 2047, - _INT32_MIN, - _INT32_MIN + 1, - _INT32_MAX, - _INT32_MAX - 1, - ], - dtype=jnp.int32, - ) - - def predicate(x): - return x > a - - r = binary_search.int32_bsearch(a.shape, predicate) - np.testing.assert_array_equal(a, r) - - def test_int32_bsearch_extreme_predicates(self): - def predicate_false(x): - return jnp.full_like(x, False) - - np.testing.assert_array_equal( - jnp.asarray([_INT32_MAX]), - binary_search.int32_bsearch((1,), predicate_false), - ) - - def predicate_true(x): - return jnp.full_like(x, True) - - np.testing.assert_array_equal( - jnp.asarray([_INT32_MIN]), - binary_search.int32_bsearch((1,), predicate_true), - ) - - def test_float32_bsearch(self): - a = jnp.asarray([1.23, 0.0, -0.0, 105.4, -1024, 4.3], dtype=jnp.float32) - - def predicate(x): - return x > a - - c = binary_search.float32_bsearch(a.shape, predicate) - # Given that the predicate is based on floating point '>' as implemented by - # JAX, we need our equality test to be based on floating point '==' as - # implemented by JAX, rather than np.testing.assert_array_equal. - # - # Some corner cases on subnormal numbers may be different, depending on what - # platform we run on. - self.assertTrue(jnp.all(a == c), f'a={a}, c={c}') - - def test_topk_mask(self): - mask = -1e10 - x = jnp.asarray([ - [1.4, 7.9, -4.3, 100, 71, 6, -1e4], - [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, -100], - ]) - - # Using exact equality here, because topk_mask guarantees it: it is just - # masking some things, not doing arithmetic on the array. - np.testing.assert_array_equal( - jnp.asarray([ - [mask, mask, mask, 100, mask, mask, mask], - [mask, mask, mask, mask, mask, 9.7, mask], - ]), - binary_search.topk_mask(x, 1, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [mask, mask, mask, 100, 71, mask, mask], - [8.3, mask, mask, mask, mask, 9.7, mask], - ]), - binary_search.topk_mask(x, 2, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [mask, 7.9, mask, 100, 71, mask, mask], - [8.3, mask, 1.3, mask, mask, 9.7, mask], - ]), - binary_search.topk_mask(x, 3, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [mask, 7.9, mask, 100, 71, 6, mask], - [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask], - ]), - binary_search.topk_mask(x, 4, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [1.4, 7.9, mask, 100, 71, 6, mask], - [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask], - ]), - binary_search.topk_mask(x, 5, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [1.4, 7.9, -4.3, 100, 71, 6, mask], - [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask], - ]), - binary_search.topk_mask(x, 6, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [1.4, 7.9, -4.3, 100, 71, 6, -1e4], - [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, -100], - ]), - binary_search.topk_mask(x, 7, mask), - ) - - def test_topp_mask(self): - probs = jnp.asarray([ - [0.0, 0.7, 0.04, 0.06, 0.2, 0.0], - [0.0, 0.2, 0.2, 0.2, 0.3, 0.1], - ]) - logits = jnp.log(probs) - np.testing.assert_allclose(jax.nn.softmax(logits), probs) - mask = -1e10 - - # Using exact equality here, because topp_mask guarantees it: it is just - # masking some things, not doing arithmetic on the array. - np.testing.assert_array_equal( - jnp.asarray([ - [mask, jnp.log(0.7), mask, mask, mask, mask], - [mask, mask, mask, mask, jnp.log(0.3), mask], - ]), - binary_search.topp_mask(logits, 0.1, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [mask, jnp.log(0.7), mask, mask, mask, mask], - [mask, mask, mask, mask, jnp.log(0.3), mask], - ]), - binary_search.topp_mask(logits, 0.3, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [mask, jnp.log(0.7), mask, mask, mask, mask], - [ - mask, - jnp.log(0.2), - jnp.log(0.2), - jnp.log(0.2), - jnp.log(0.3), - mask, - ], - ]), - binary_search.topp_mask(logits, 0.4, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [mask, jnp.log(0.7), mask, mask, jnp.log(0.2), mask], - [ - mask, - jnp.log(0.2), - jnp.log(0.2), - jnp.log(0.2), - jnp.log(0.3), - mask, - ], - ]), - binary_search.topp_mask(logits, 0.8, mask), - ) - np.testing.assert_array_equal( - jnp.asarray([ - [mask, jnp.log(0.7), mask, jnp.log(0.06), jnp.log(0.2), mask], - [ - mask, - jnp.log(0.2), - jnp.log(0.2), - jnp.log(0.2), - jnp.log(0.3), - jnp.log(0.1), - ], - ]), - binary_search.topp_mask(logits, 0.95, mask), - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/checkpoint_importer.py b/t5x-main/t5x/checkpoint_importer.py deleted file mode 100644 index 312ecf5df966960c596d31fd1ee63328f3709c1e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/checkpoint_importer.py +++ /dev/null @@ -1,551 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5 Checkpoint Importer.""" - -import abc -import asyncio -from concurrent.futures import thread -import re -from typing import Any, Callable, Mapping, MutableMapping, Optional, Sequence, Tuple, Union - -from flax import traverse_util -import jax -from jax import numpy as jnp -import numpy as np -import tensorflow as tf -import tensorstore as ts - -ArrayType = Union[np.ndarray, jnp.ndarray, jax.Array] -ScalarOrArrayType = Union[int, float, ArrayType] - - -class LazyArray(metaclass=abc.ABCMeta): - """Lazily and asynchronously loads an array. - - LazyArray behaves in the same way as a `numpy` or `jax.numpy` array - while instantiating lazily. All properties, including shape, dtype, and nbytes - are created when the LazyArray is created, but no data is materialized until - `get` or `get_async` are called. Data is materialized using a specified - `get_fn`. - - This class can be used to implement lazy restoration in checkpointing APIs, - where the data is only read from disk when explicitly needed by the user. - """ - - def __init__( - self, shape: Sequence[int], dtype: jnp.dtype, get_fn: Callable[[], Any] - ): - self._shape = tuple(shape) if shape is not None else shape - self._dtype = jnp.dtype(dtype) if dtype is not None else dtype - self._get_fn = get_fn - - @property - def shape(self) -> Tuple[int, ...]: - return self._shape - - @property - def dtype(self) -> jnp.dtype: - return self._dtype - - @property - def nbytes(self) -> int: - return np.prod(self._shape) * self._dtype.itemsize - - def astype(self, dtype: np.dtype) -> 'LazyArray': - return type(self)(self._shape, dtype, self._get_fn) # pytype: disable=not-instantiable - - @abc.abstractmethod - async def get_async(self) -> ScalarOrArrayType: - raise NotImplementedError - - @abc.abstractmethod - def get(self) -> ScalarOrArrayType: - raise NotImplementedError - - def __repr__(self): - return f'{type(self).__name__}(shape={self.shape}, dtype={self.dtype})' - - -# TODO(brianlester): The choice between using a `LazyTreadPoolArray` or a -# `LazyAwaitableArray` is dependent on if the user provided `get_fn` is blocking -# or async respectively, if we can detect which it is, we can automatically -# proxy to the correct subclass. We cannot detect of `get_fn` is a lambda that -# wraps an async call so this isn't possible yet. Add this dispatch once we are -# able to detect that, python3.8+ can detect async for partial'ed functions but -# not lambdas. -class LazyThreadPoolArray(LazyArray): - """Lazily and asynchronously loads an array when the `get_fn` blocks.""" - - # Uses a global threadpool to enable asynchronous loading. - executor = thread.ThreadPoolExecutor() - - async def get_async(self): - return await asyncio.wrap_future(self.executor.submit(self.get)) - - def get(self) -> ScalarOrArrayType: - arr = self._get_fn() - if arr.dtype != self.dtype: - arr = arr.astype(self.dtype) - return arr - - -class LazyAwaitableArray(LazyArray): - """Lazily and asynchronously loads an array when the `get_fn` is async. - - Note: - The synchronous load method `.get` requires the asyncio event loop and - calling `.run_until_complete`. This is not supported when the event loop is - already running (for example, from inside another async function). - - Note: - Currently, this class has a few helper methods for creating a - LazyAwaitableArray when the input could be either an array, or a TensorStore - spec. Most people use async code when dealing with TensorStore so the - classmethods have been placed here. When someone eventually uses a blocking - function to read from TensorStore they can be moved to the LazyArray base - class. - """ - - async def get_async(self): - async def _get_and_cast(): - # Pytype has a false positive here, where it treats our _get_fn (_read_ts - # in this case) as having a return type of `np.ndarray` instead of - # wrapping it in an Awaitable. Related to this bug - # https://github.com/google/pytype/issues/527 - arr = await self._get_fn() # pytype: disable=bad-return-type - if arr.dtype != self.dtype: - arr = arr.astype(self.dtype) - return arr - - return await _get_and_cast() - - def get(self) -> ScalarOrArrayType: - return asyncio.run(self.get_async()) - - @classmethod - def from_tensor_store_spec( - cls, - ts_spec: ts.Spec, - get_fn: Callable[[], np.ndarray], - dtype: Optional[jnp.dtype] = None, - ) -> 'LazyAwaitableArray': - """Create a LazyAwaitableArray based on a tensorstore.Spec.""" - ts_spec = ts_spec.to_json() - shape = ts_spec['metadata']['shape'] - if dtype is None: - dtype = jnp.dtype(ts_spec['dtype']) - else: - dtype = jnp.dtype(dtype) - # v2 T5X checkpoints use uint16 as the TensorStore datatype and then store - # the bfloat16 bytes as in in the 16 bytes uint16 has (no actual cast). When - # When reading the dtype from the TensorStore, if we keep the dtype of these - # v2 checkpoints as np.uint16 then the _get_fn (which has a possible cast to - # support the `restore_dtype` parameter for the checkpointer) will actually - # cast the bfloat16 values to uint16, generally resulting in an array of all - # zeros. This check avoid the actual cast to uint16 by replacing the dtype. - if dtype == np.uint16: - dtype = jnp.bfloat16 - return cls(shape, dtype, get_fn) - - @classmethod - def from_array( - cls, - array: np.ndarray, - get_fn: Callable[[], np.ndarray], - dtype: Optional[jnp.dtype] = None, - ) -> 'LazyAwaitableArray': - """Create a LazyAwaitableArray based on an array or python number.""" - if dtype is None: - dtype = array.dtype - else: - dtype = jnp.dtype(dtype) - return cls(array.shape, dtype, get_fn) - - @classmethod - def from_tensor_store_spec_or_array( - cls, - maybe_ts_spec: Union[ts.Spec, np.ndarray], - get_fn: Callable[[], np.ndarray], - dtype: Optional[jnp.dtype] = None, - ) -> 'LazyAwaitableArray': - """Create a LazyAwaitableArray based on an array or a tensorstore.Spec.""" - if isinstance(maybe_ts_spec, ts.Spec): - return cls.from_tensor_store_spec(maybe_ts_spec, get_fn, dtype=dtype) - return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype) - - -class CheckpointTranslator: - """Utility class for defining mapping rules from one flatdict to another. - - We assume a checkpoint is loaded as a dictionary with flattened keys of the - form: 'name0/name1/name2/.../nameN' - - A rule is added with the 'add' decorator, which takes a regex matching rule - and wraps a conversion function, feeding it (opts, key, val, **regex_groups) - where opts is a dict containing apply-time keyword options for use by the - conversion functions. - """ - - def __init__(self): - self.rules = [] - - def add(self, pattern): - """Adds a new keyval conversion rule. - - Args: - pattern: regex with capture groups for matching given sets of model - variables. We terminate all regexes with '$' to force complete matches. - - Returns: - Translation function decorator for associating with the provided - pattern. - """ - - def register_translation_fn_decorator(fn): - # We force a complete match by adding end-of-string match. - self.rules.append((re.compile(pattern + '$'), fn)) - return fn - - return register_translation_fn_decorator - - def apply(self, flatdict, **opts): - """Applies rules to a flattened dictionary. - - Args: - flatdict: flat-key dictionary of variables. - **opts: additional config options for translation rules supplied at - application time. - - Returns: - Checkpoint data with translated key/values in flat-key dict format. - """ - new_dict = {} - unmatched = {} - for k, v in flatdict.items(): - matched = False - for rule_pat, rule_fn in self.rules: - if rule_pat.match(k): - groups = rule_pat.match(k).groups() - new_k, new_v = rule_fn(opts, k, v, *groups) - if new_k is not None: - new_dict[new_k] = new_v - matched = True - break - if not matched: - unmatched[k] = v - - # We force every key-value pair in checkpoint to have a rule associated with - # it. - if unmatched: - raise ValueError('Unmapped tensor keys exist: %s' % unmatched) - - return new_dict - - -# Create a translation rule set for importing T5 & T5.1.1 model checkpoints. -# ----------------------------------------------------------------------------- -t5_importer = CheckpointTranslator() - -# Name mappings. -SLOT_MAP = {'_slot_vc': 'v_col', '_slot_vr': 'v_row', '_slot_v': 'v'} -TOWER_MAP = {'transformer': 'decoder'} - - -@t5_importer.add(r'global_step') -def global_step(opts, key, val): - del opts, key - return ( - 'state/step', - val.astype(np.int32).get() if isinstance(val, LazyArray) else val, - ) - - -@t5_importer.add(r'shared/embedding(\w*)') -def shared_embeddings(opts, key, val, slot): - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - newkey = f'{prefix}/token_embedder/embedding{suffix}' - return newkey, val - - -@t5_importer.add(r'(encoder|decoder|transformer)/embedding(\w*)') -def separate_embeddings(opts, key, val, encdec, slot): - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - encdec = TOWER_MAP.get(encdec, encdec) - newkey = f'{prefix}/{encdec}/token_embedder/embedding{suffix}' - return newkey, val - - -# In the Mesh TensorFlow T5 code, relative_attention_bias always occurs in layer -# 0 because SelfAttention precedes other sublayers within the same block. -@t5_importer.add( - r'(encoder|decoder|transformer)/block_(\d+)/layer_000/SelfAttention/relative_attention_bias(\w*)' -) -def rel_embeddings(opts, key, val, encdec, blocknum, slot): - """Process relpos bias assuming that they are not shared across layers.""" - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - blocknum = int(blocknum) - encdec = TOWER_MAP.get(encdec, encdec) - # At this point, we can't determine whether the relpos bias was shared across - # layers or not. We first assume that it was not shared. During post - # processing, we remove the layers_0 scope if it was shared. - newkey = ( - f'{prefix}/{encdec}/layers_{blocknum}/relpos_bias/rel_embedding{suffix}' - ) - return newkey, val - - -@t5_importer.add( - r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/(SelfAttention|EncDecAttention)/(q|k|v|o)(\w*)' -) -def attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot): - """Process attention layers.""" - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - blocknum = int(blocknum) - encdec = TOWER_MAP.get(encdec, encdec) - matrix = {'q': 'query', 'k': 'key', 'v': 'value', 'o': 'out'}[qkvo] - - if encdec == 'encoder': - attntype = 'attention' - else: - attntype = { - 'SelfAttention': 'self_attention', - 'EncDecAttention': 'encoder_decoder_attention', - }[attntype] - newkey = ( - f'{prefix}/{encdec}/layers_{blocknum}/{attntype}/{matrix}/kernel{suffix}' - ) - return newkey, val - - -@t5_importer.add( - r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/DenseReluDense/(wi|wo)(?:_(\d+))?/kernel(\w*)' -) -def mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot): - """Process MLP blocks.""" - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - blocknum = int(blocknum) - encdec = TOWER_MAP.get(encdec, encdec) - io_num = f'_{io_num}' if io_num else '' - newkey = f'{prefix}/{encdec}/layers_{blocknum}/mlp/{io_name}{io_num}/kernel{suffix}' - return newkey, val - - -@t5_importer.add( - r'(encoder|decoder|transformer)/block_(\d+)/layer_(\d+)/(?:layer|rms)_norm/scale(\w*)' -) -def layernorms(opts, key, val, encdec, blocknum, lyrnum, slot): - """Process layer norms assuming that they are pre-layernorms.""" - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - lyrnum = int(lyrnum) - - if encdec == 'transformer': - layernorm_type = ['pre_self_attention_layer_norm', 'pre_mlp_layer_norm'][ - lyrnum - ] - - elif encdec == 'encoder': - layernorm_type = ['pre_attention_layer_norm', 'pre_mlp_layer_norm'][lyrnum] - else: # decoder - layernorm_type = [ - 'pre_self_attention_layer_norm', - 'pre_cross_attention_layer_norm', - 'pre_mlp_layer_norm', - ][lyrnum] - - encdec = TOWER_MAP.get(encdec, encdec) - newkey = ( - f'{prefix}/{encdec}/layers_{int(blocknum)}/{layernorm_type}/scale{suffix}' - ) - return newkey, val - - -@t5_importer.add( - r'(encoder|decoder|transformer)/(?:final_layer|rms)_norm/scale(\w*)' -) -def final_layernorms(opts, key, val, encdec, slot): - """Process final layer norms.""" - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - norm = { - 'encoder': 'encoder_norm', - 'decoder': 'decoder_norm', - 'transformer': 'decoder_norm', - }[encdec] - encdec = TOWER_MAP.get(encdec, encdec) - newkey = f'{prefix}/{encdec}/{norm}/scale{suffix}' - return newkey, val - - -@t5_importer.add(r'(?:decoder|transformer)/logits/kernel(\w*)') -def final_logits(opts, key, val, slot): - del opts, key - prefix = 'state/param_states' if slot else 'target' - suffix = '/' + SLOT_MAP[slot] if slot else '' - newkey = f'{prefix}/decoder/logits_dense/kernel{suffix}' - return newkey, val - - -def _add_missing_param_states(t5_data): - """Add dummy slots that Flax Adafactor requires but TF does not.""" - updates = {} - for k in t5_data: - if k.startswith('target'): - state_leaf = 'state/param_states' + k[len('target') :] - updates[state_leaf + '/m'] = np.zeros((1,), np.float32) - if state_leaf + '/v' in t5_data: - updates[state_leaf + '/v_row'] = np.zeros((1,), np.float32) - updates[state_leaf + '/v_col'] = np.zeros((1,), np.float32) - elif state_leaf + '/v_row' in t5_data: - updates[state_leaf + '/v'] = np.zeros((1,), np.float32) - t5_data.update(**updates) - return t5_data - - -def _maybe_correct_relpos_bias(t5_data): - """Correct the relpos_bias format if it is shared across layers.""" - max_layer_ind = 0 - for k, v in t5_data.items(): - match = re.search(r'layers_(\d+)/relpos_bias', k) - if match: - layer_ind = int(match.groups()[0]) - max_layer_ind = max(max_layer_ind, layer_ind) - - modified_dict = {} - if max_layer_ind == 0: - # Relative position biases are shared across layers - for k, v in t5_data.items(): - new_k = re.sub(r'layers_\d+/relpos_bias', 'relpos_bias', k) - modified_dict[new_k] = v - else: - # Relative position biases are unique in each layer. No more processing is - # necessary. - modified_dict = t5_data - - return modified_dict - - -# Load checkpoint, translate, and update flax optimizer and model. -# ----------------------------------------------------------------------------- -def load_tf_ckpt(path): - """Load a TF checkpoint as a flat dictionary of numpy arrays.""" - ckpt_reader = tf.train.load_checkpoint(path) - ckpt_shape_map = ckpt_reader.get_variable_to_shape_map() - ckpt_dtype_map = ckpt_reader.get_variable_to_dtype_map() - datamap = { # pylint: disable=g-complex-comprehension - k: LazyThreadPoolArray( - s, - jnp.dtype(ckpt_dtype_map[k].as_numpy_dtype), - lambda x=k: ckpt_reader.get_tensor(x), - ) - for k, s in ckpt_shape_map.items() - } - return datamap - - -def _update_state_dict( - state_dict: Mapping[str, Any], - t5_data: MutableMapping[str, LazyArray], - strict: bool = True, -) -> Mapping[str, Any]: - """Update flax optimizer for T5 model. - - Args: - state_dict: Optimizer to update with T5 parameters. - t5_data: T5 model parameters, typically loaded from a checkpoint. - strict: If True requires that optimizer and t5_data mappings contain the - same set of names (variables). If False, updating will succeed even if - t5_data contains variables not in the optimizer. If the optimizer has - variables not in t5_data, this function will still fail. - - Returns: - Updated optimizer. - """ - flat_state_dict = traverse_util.flatten_dict(state_dict, sep='/') - - # Remove parameters from the checkpoint not found in the optimizer (this - # allows us to load checkpoints that contain more parameters than our current - # model). - if not strict: - for k in list(t5_data): - if k not in flat_state_dict: - t5_data.pop(k) - - # Shape check. - for k, v in t5_data.items(): - if flat_state_dict[k].shape != v.shape: - raise ValueError( - f'Variable {k} has shape {v.shape} != {flat_state_dict[k].shape}' - ) - flat_state_dict = t5_data - state_dict = traverse_util.unflatten_dict( - {tuple(k.split('/')): v for k, v in flat_state_dict.items()} - ) - return state_dict - - -def restore_from_t5_checkpoint( - state_dict: Mapping[str, Any], - path: str, - lazy_parameters: bool = False, - strict: bool = True, - translator: Optional[CheckpointTranslator] = None, -) -> Mapping[str, Any]: - """Load T5 checkpoint and update Adafactor optimizer and T5 model from it. - - We require that the final translated checkpoint structure exactly matches - that of the Flax Adafactor + Transformer data, up to shape agreement of - the leaves. - - Args: - state_dict: Flax Adafactor Optimizer for T5 transformer encoder-decoder. - path: a path to checkpoint file or directory. - lazy_parameters: whether to leave the parameters as LazyArrays to preserve - memory. - strict: If True requires that optimizer and t5_data mappings contain the - same set of names (variables). If False, updating will succeed even if - t5_data contains variables not in the optimizer. If the optimizer has - variables not in t5_data, this function will still fail. - translator: The mapping rules for conversion. If None, then default T5 - conversion rules will be used. - - Returns: - Adafactor optimizer updated with parameters and optimizer state from - T5 checkpoint. - """ - if translator is None: - translator = t5_importer - ckpt_data = load_tf_ckpt(path) - t5_data = translator.apply(ckpt_data) - t5_data = _add_missing_param_states(t5_data) - t5_data = _maybe_correct_relpos_bias(t5_data) - state_dict = _update_state_dict(state_dict, t5_data, strict=strict) - if not lazy_parameters: - state_dict = jax.tree.map( - lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict - ) - return state_dict diff --git a/t5x-main/t5x/checkpoint_importer_test.py b/t5x-main/t5x/checkpoint_importer_test.py deleted file mode 100644 index 5ae0b8ad3cd60360c20e85407e986e1ce2224aa6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/checkpoint_importer_test.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.checkpoint_importer.""" - -import json -import os - -from absl import flags -from absl.testing import absltest -import jax -import numpy as np -from t5x import checkpoint_importer -import tensorflow as tf - - -class CheckpointImporterTest(absltest.TestCase): - - def test_rel_embeddings_shared_layers(self): - # This represents a ckpt where the Mesh TensorFlow's - # transformer_layers.SelfAttention.relative_attention_type = "bias_shared", - # i.e., the same relative attention parameters are shared by all layers - # within the (en|de)coder. - ckpt_data = { - 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': 1, - 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': 2, - 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v': ( - 3 - ), - } - t5_data = checkpoint_importer.t5_importer.apply(ckpt_data) - t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data) - expected = { - 'target/encoder/relpos_bias/rel_embedding': 1, - 'target/decoder/relpos_bias/rel_embedding': 2, - 'state/param_states/decoder/relpos_bias/rel_embedding/v': 3, - } - self.assertEqual(t5_data, expected) - - def test_rel_embeddings_per_layer(self): - # This represents a ckpt where the Mesh TensorFlow's - # transformer_layers.SelfAttention.relative_attention_type = "bias", i.e., - # each layer has its own relative attention parameters. - ckpt_data = { - 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': 1, - 'encoder/block_001/layer_000/SelfAttention/relative_attention_bias': 2, - 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': 3, - 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v': ( - 4 - ), - 'decoder/block_011/layer_000/SelfAttention/relative_attention_bias': 5, - } - t5_data = checkpoint_importer.t5_importer.apply(ckpt_data) - t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data) - expected = { - 'target/encoder/layers_0/relpos_bias/rel_embedding': 1, - 'target/encoder/layers_1/relpos_bias/rel_embedding': 2, - 'target/decoder/layers_0/relpos_bias/rel_embedding': 3, - 'state/param_states/decoder/layers_0/relpos_bias/rel_embedding/v': 4, - 'target/decoder/layers_11/relpos_bias/rel_embedding': 5, - } - self.assertEqual(t5_data, expected) - - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/checkpoint_utils.py b/t5x-main/t5x/checkpoint_utils.py deleted file mode 100644 index 4ed726df3db32fe7ac86a781d3551639abe10727..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/checkpoint_utils.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Checkpoint helper functions for managing checkpoints. - -Supports marking checkpoints as pinned to exclude them from the checkpointer -removal process. -""" - -import enum -import os -from typing import Any, BinaryIO, Optional, Union - -from absl import logging -from etils import epath -import msgpack -import orbax.checkpoint as ocp -from tensorflow.io import gfile - - -# PINNED file in the checkpoint directory indicates that the checkpoint should -# not be removed during the automatic pruning of old checkpoints. -_PINNED_CHECKPOINT_FILENAME = 'PINNED' - -PyTree = Any - - -def pinned_checkpoint_filepath(ckpt_dir: str) -> str: - """Full path of the pinned checkpoint file.""" - return os.path.join(ckpt_dir, _PINNED_CHECKPOINT_FILENAME) - - -def is_pinned_checkpoint(ckpt_dir: str) -> bool: - """Returns whether the checkpoint is pinned, and should NOT be removed.""" - pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) - if gfile.exists(pinned_ckpt_file): - return True - return False - - -def pin_checkpoint(ckpt_dir: str, txt: str = '1') -> None: - """Pin a checkpoint so it does not get deleted by the normal pruning process. - - Creates a PINNED file in the checkpoint directory to indicate the checkpoint - should be excluded from the deletion of old checkpoints. - - Args: - ckpt_dir: The checkpoint step dir that is to be always kept. - txt: Text to be written into the checkpoints ALWAYS_KEEP me file. - """ - pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) - with gfile.GFile(pinned_ckpt_file, 'w') as f: - logging.debug('Write %s file : %s.', pinned_ckpt_file, txt) - f.write(txt) - - -def unpin_checkpoint(ckpt_dir: str) -> None: - """Removes the pinned status of the checkpoint so it is open for deletion.""" - if not is_pinned_checkpoint(ckpt_dir): - logging.debug('%s is not PINNED. Nothing to do here.', ckpt_dir) - return - try: - pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) - logging.debug('Remove %s file.', pinned_ckpt_file) - gfile.rmtree(pinned_ckpt_file) - except IOError: - logging.exception('Failed to unpin %s', ckpt_dir) - - -def remove_checkpoint_dir(ckpt_dir: str) -> None: - """Removes the checkpoint dir if it is not pinned.""" - if not is_pinned_checkpoint(ckpt_dir): - logging.info('Deleting checkpoint: %s', ckpt_dir) - gfile.rmtree(ckpt_dir) - else: - logging.info('Keeping pinned checkpoint: %s', ckpt_dir) - - -def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None: - """Removes dataset checkpoints if the checkpoint is not pinned.""" - if not is_pinned_checkpoint(ckpt_dir): - train_ds_pattern = os.path.join(ckpt_dir, train_ds_prefix + '*') - logging.info('Deleting dataset checkpoint: %s', train_ds_pattern) - for file in gfile.glob(train_ds_pattern): - gfile.remove(file) - else: - logging.info('Keeping pinned checkpoint: %s', ckpt_dir) - - -def _read_msgpack_keys(file_like: BinaryIO) -> PyTree: - """Returns a tree containing all keys but no values from a msgpack file.""" - unpacker = msgpack.Unpacker(file_like) - num_keys = unpacker.read_map_header() - ret = {} - - # Contains references to the parent tree for each key to visit in the - # msgpack file traversal. - visit_stack = [ret for _ in range(num_keys)] - while visit_stack: - parent_dict = visit_stack.pop() - key = unpacker.unpack() - if isinstance(key, bytes): - key = str(unpacker.unpack(), 'utf-8') - - # Check if the value object is map. - try: - n = unpacker.read_map_header() - ref = parent_dict[key] = {} - visit_stack.extend(ref for _ in range(n)) - except msgpack.UnpackValueError: - # Not a map so skip unpacking the value object and record the current key. - unpacker.skip() - parent_dict[key] = None - - return ret - - -def _contains_ts_spec(tree: PyTree) -> bool: - """Returns whether the a Pytree contains a serialized ts.Spec object.""" - to_visit = [tree] - while to_visit: - cur = to_visit.pop() - if cur.keys() >= {'driver', 'kvstore', 'metadata'}: - return True - to_visit.extend(v for v in cur.values() if isinstance(v, dict)) - return False - - -# Constant copied from orbax/checkpoint/pytree_checkpoint_handler.py -_METADATA_FILE = '_METADATA' - - -def _contains_orbax_metadata(ckpt_path: str) -> bool: - metadata = os.path.join(os.path.dirname(ckpt_path), _METADATA_FILE) - return gfile.exists(metadata) - - -class CheckpointTypes(enum.Enum): - ORBAX = 'ORBAX' - T5X = 'T5X' - T5X_TF = 'T5X_TF' - - -def _warn_if_unexpected_type( - checkpoint_path, checkpoint_type, expected, extra_warn_log -): - """Warns the user if unexpected type found.""" - if expected is None or checkpoint_type == expected: - return - - logging.warning( - 'Expected the checkpoint at %s to be %s format, but' - ' the actual detected format was %s.', - checkpoint_path, - expected, - checkpoint_type, - ) - logging.warning(extra_warn_log) - - -def detect_checkpoint_type( - checkpoint_path: epath.PathLike, expected: Optional[CheckpointTypes] = None -) -> CheckpointTypes: - """Returns the checkpoint type by reading the `.checkpoint` metadata file. - - Args: - checkpoint_path: The path of the `.checkpoint` file. - expected: The expected checkpoint type. If the checkpoint type is not as - expected, this function will log a warning but will not raise an error. - - Returns: - The checkpoint type. - """ - if _contains_orbax_metadata(checkpoint_path): - checkpoint_type = CheckpointTypes.ORBAX - _warn_if_unexpected_type( - checkpoint_path, - checkpoint_type, - expected, - f'Found `{_METADATA_FILE}` in the checkpoint directory, which only ' - 'appears in Orbax checkpoints', - ) - return checkpoint_type - - with gfile.GFile(checkpoint_path, 'rb') as fp: - raw_contents = fp.read(21) - if raw_contents == b'model_checkpoint_path': - checkpoint_type = CheckpointTypes.T5X_TF - _warn_if_unexpected_type( - checkpoint_path, - checkpoint_type, - expected, - 'The checkpoint file was not a msgpack, and had the string ' - '"model_checkpoint_path", so it was assumed to be in the T5X ' - 'TensorFlow format.', - ) - return checkpoint_type - - # Assume that if the msgpack file has exactly 'version' and 'optimizer' as - # keys, it is a T5X checkpoint. Checkpoints that were created a long time - # ago may not contain these keys, so there is a backup ts.Spec check - # as well. - fp.seek(0) - key_tree = _read_msgpack_keys(fp) - if set(key_tree.keys()) == {'version', 'optimizer'}: - checkpoint_type = CheckpointTypes.T5X - _warn_if_unexpected_type( - checkpoint_path, - checkpoint_type, - expected, - 'Top-level keys in the msgpack file were "version" and "optimizer", ' - 'thus the checkpoint was assumed to be in the T5X format.', - ) - return checkpoint_type - elif _contains_ts_spec(key_tree): - # If the checkpoint contains a ts.Spec, it could either be a T5X - # checkpoint or an early version Flax checkpoint. The latter is - # essentially deprecated but should also be handled by the T5X - # Checkpointer, so we return T5X here for simplicity. - checkpoint_type = CheckpointTypes.T5X - _warn_if_unexpected_type( - checkpoint_path, - checkpoint_type, - expected, - 'Found ts.Spec in the checkpoint msgpack file, thus the checkpoint' - ' was assumed to be in the T5X format.', - ) - return checkpoint_type - else: - checkpoint_type = CheckpointTypes.T5X - _warn_if_unexpected_type( - checkpoint_path, - checkpoint_type, - expected, - 'Did not detect ts.Spec nor the {"version", "optimizer"} keys in the' - 'checkpoint msgpack file, so the checkpoint was assumed to be ' - 'written with T5X.', - ) - return checkpoint_type - - -def _is_supported_empty_value(value: Any) -> bool: - if hasattr(ocp.type_handlers, 'is_supported_empty_aggregation_type'): - return ocp.type_handlers.is_supported_empty_aggregation_type(value) - return ocp.type_handlers.is_supported_empty_value(value) - - -def get_restore_parameters(directory: epath.Path, structure: PyTree) -> PyTree: - """Construct ParamInfos tree needed for restoration. - - ParamInfos are constructed from the structure of the original checkpoint. - - Args: - directory: Checkpoint directory. - structure: The structure of the original checkpoint. - - Returns: - PyTree of `ParamInfo`. - """ - flat_structure = ocp.tree.to_flat_dict(structure, keep_empty_nodes=True) - param_names = ocp.tree.get_param_names(structure) - flat_param_names = ocp.tree.to_flat_dict(param_names, keep_empty_nodes=True) - flat_param_infos = {} - is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory) - ts_context = ocp.type_handlers.get_ts_context() - - def _get_param_info( - name: str, - meta_or_value: Union[Any, ocp.metadata.tree.ValueMetadataEntry], - ) -> Union[ocp.type_handlers.ParamInfo, Any]: - if _is_supported_empty_value(meta_or_value): - # Empty node, ParamInfo should not be returned. - return meta_or_value - elif not isinstance(meta_or_value, ocp.metadata.tree.ValueMetadataEntry): - # Aggregated value. - skip_deserialize = True - else: - skip_deserialize = meta_or_value.skip_deserialize - return ocp.type_handlers.ParamInfo( - name=name, - path=directory / name, - parent_dir=directory, - skip_deserialize=skip_deserialize, - is_ocdbt_checkpoint=is_ocdbt_checkpoint, - ts_context=ts_context, - ) - - for key, meta in flat_structure.items(): - flat_param_infos[key] = _get_param_info(flat_param_names[key], meta) - - return ocp.tree.from_flat_dict(flat_param_infos, target=structure) diff --git a/t5x-main/t5x/checkpoint_utils_test.py b/t5x-main/t5x/checkpoint_utils_test.py deleted file mode 100644 index e91b35fd26be1ada91e1dfd56b6e4d61365a8f8e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/checkpoint_utils_test.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.checkpoint_utils.""" - -import os -import traceback - -from absl.testing import absltest -from t5x import checkpoint_utils -from tensorflow.io import gfile - -TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") - - -class CheckpointsUtilsTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.checkpoints_dir = self.create_tempdir() - self.ckpt_dir_path = self.checkpoints_dir.full_path - self.pinned_ckpt_file = os.path.join(self.ckpt_dir_path, "PINNED") - self.checkpoints_dir.create_file("checkpoint") - # Create a `train_ds` file representing the dataset checkpoint. - train_ds_basename = "train_ds-00000-of-00001" - self.train_ds_file = os.path.join(self.ckpt_dir_path, train_ds_basename) - self.checkpoints_dir.create_file(train_ds_basename) - - def test_always_keep_checkpoint_file(self): - self.assertEqual( - "/path/to/ckpt/dir/PINNED", - checkpoint_utils.pinned_checkpoint_filepath("/path/to/ckpt/dir"), - ) - - def test_is_pinned_checkpoint_false_by_default(self): - # Ensure regular checkpoint without PINNED file. - self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) - - # Validate checkpoints are not pinned by default. - self.assertFalse(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) - - def test_is_pinned_checkpoint(self): - # Ensure the checkpoint directory as pinned. - pinned_ckpt_testdata = os.path.join(TESTDATA, "pinned_ckpt_dir") - pinned_file = os.path.join(pinned_ckpt_testdata, "PINNED") - self.assertTrue(gfile.exists(pinned_file)) - - # Test and validate. - self.assertTrue(checkpoint_utils.is_pinned_checkpoint(pinned_ckpt_testdata)) - - def test_is_pinned_missing_ckpt(self): - self.assertFalse( - checkpoint_utils.is_pinned_checkpoint( - os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist") - ) - ) - - def test_pin_checkpoint(self): - # Ensure directory isn't already pinned. - self.assertFalse(gfile.exists(self.pinned_ckpt_file)) - - # Test. - checkpoint_utils.pin_checkpoint(self.ckpt_dir_path) - - # Validate. - self.assertTrue(gfile.exists(self.pinned_ckpt_file)) - with open(self.pinned_ckpt_file) as f: - self.assertEqual("1", f.read()) - - def test_pin_checkpoint_txt(self): - checkpoint_utils.pin_checkpoint(self.ckpt_dir_path, "TEXT_IN_PINNED") - self.assertTrue(os.path.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) - with open(self.pinned_ckpt_file) as f: - self.assertEqual("TEXT_IN_PINNED", f.read()) - - def test_unpin_checkpoint(self): - # Mark the checkpoint directory as pinned. - self.checkpoints_dir.create_file("PINNED") - self.assertTrue(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) - - # Test. - checkpoint_utils.unpin_checkpoint(self.ckpt_dir_path) - - # Validate the "PINNED" checkpoint file got removed. - self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) - - def test_unpin_checkpoint_does_not_exist(self): - missing_ckpt_path = os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist") - self.assertFalse(gfile.exists(missing_ckpt_path)) - - # Test. Assert does not raise error. - try: - checkpoint_utils.unpin_checkpoint(missing_ckpt_path) - except IOError: - # TODO(b/172262005): Remove traceback.format_exc() from the error message. - self.fail("Unpin checkpoint failed with: %s" % traceback.format_exc()) - - def test_remove_checkpoint_dir(self): - # Ensure the checkpoint directory is setup. - assert gfile.exists(self.ckpt_dir_path) - - # Test. - checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) - - # Validate the checkpoint directory got removed. - self.assertFalse(gfile.exists(self.ckpt_dir_path)) - - def test_remove_checkpoint_dir_pinned(self): - # Mark the checkpoint directory as pinned so it does not get removed. - self.checkpoints_dir.create_file("PINNED") - - # Test. - checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) - - # Validate the checkpoint directory still exists. - self.assertTrue(gfile.exists(self.ckpt_dir_path)) - - def test_remove_dataset_checkpoint(self): - # Ensure the checkpoint directory is setup. - assert gfile.exists(self.ckpt_dir_path) - - # Test. - checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") - - # Validate the checkpoint directory got removed. - self.assertFalse(gfile.exists(self.train_ds_file)) - self.assertTrue(gfile.exists(self.ckpt_dir_path)) - - def test_remove_dataset_checkpoint_pinned(self): - # Mark the checkpoint directory as pinned so it does not get removed. - self.checkpoints_dir.create_file("PINNED") - - # Test. - checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") - - # Validate the checkpoint directory still exists. - self.assertTrue(gfile.exists(self.train_ds_file)) - self.assertTrue(gfile.exists(self.ckpt_dir_path)) - - def test_detect_checkpoint_type(self): - tf_ckpt = os.path.join(TESTDATA, "mtf_tiny_t5", "checkpoint") - orbax_ckpt = os.path.join(TESTDATA, "tiny_orbax", "1", "checkpoint") - t5_ckpt = os.path.join(TESTDATA, "tiny_t5", "checkpoint_1", "checkpoint") - - ret = checkpoint_utils.detect_checkpoint_type( - t5_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X - ) - self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X) - - ret = checkpoint_utils.detect_checkpoint_type( - tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF - ) - self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X_TF) - - ret = checkpoint_utils.detect_checkpoint_type( - orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.ORBAX - ) - self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X) - - with self.assertLogs(level="WARN") as log_output: - checkpoint_utils.detect_checkpoint_type( - tf_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X - ) - self.assertRegex( - log_output[0][0].message, - ".*to be CheckpointTypes.T5X format, but the actual detected format was" - " CheckpointTypes.T5X_TF.*", - ) - - with self.assertLogs(level="WARN") as log_output: - checkpoint_utils.detect_checkpoint_type( - orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.T5X_TF - ) - self.assertRegex( - log_output[0][0].message, - ".*to be CheckpointTypes.T5X_TF format, but the actual detected format" - " was CheckpointTypes.T5X.*", - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/t5x-main/t5x/checkpoints.py b/t5x-main/t5x/checkpoints.py deleted file mode 100644 index ada1b4e7af2a5cf3889be8473c85be6024598540..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/checkpoints.py +++ /dev/null @@ -1,2575 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for reading and writing sharded checkpoints. - -The checkpointing utilities here can be used in two ways. The first is to use -the `Checkpointer` class. This requires having an optimizer and various -partitioning utilities setup, but allows for reading and writing of partitioned -parameters. It also allows different hosts to read different parameter -partitions in a multi-host setup, which results in much faster reads. This is -normally used during training where you have already created an optimizer based -on a config. - -The second way is to use the `load_t5x_checkpoint` function. This doesn't -require an optimizer to get given up front so it is useful for things like -debugging and analysis of learned weights. However, this means that we cannot do -partitioned reads so loading will be slower than that `Checkpointer` class. -""" - -import asyncio -import dataclasses -import functools -import os -import re -import subprocess -import time -from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union - -from absl import logging -import clu.data -from etils import epath -import flax -from flax import serialization -from flax import traverse_util -import gin -import jax -from jax import monitoring -from jax.experimental import multihost_utils -from jax.experimental.array_serialization import serialization as array_serialization -import jax.numpy as jnp -import numpy as np -import orbax.checkpoint as ocp -from t5x import checkpoint_importer -from t5x import checkpoint_utils -from t5x import checkpoints_utils -from t5x import optimizers -from t5x import partitioning -from t5x import state_utils -from t5x import train_state as train_state_lib -# pylint: disable=g-importing-member -from t5x.checkpoints_utils import all_dataset_checkpoint_steps -from t5x.checkpoints_utils import all_steps -from t5x.checkpoints_utils import get_checkpoint_dir -from t5x.checkpoints_utils import get_checkpoint_prefix -from t5x.checkpoints_utils import get_step_from_checkpoint_dir -from t5x.checkpoints_utils import latest_step -# pylint: enable=g-importing-member -import tensorflow as tf -from tensorflow.io import gfile -import tensorstore as ts -import typing_extensions - -from tensorboard.backend.event_processing import directory_watcher -from tensorboard.backend.event_processing import event_file_loader -from tensorboard.backend.event_processing import io_wrapper - -PartitionSpec = partitioning.PartitionSpec -PyTree = Any -PyTreeDef = jax.tree_util.PyTreeDef -LazyArray = checkpoint_importer.LazyArray -LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray -LazyThreadPoolArray = checkpoint_importer.LazyThreadPoolArray -Dataset = Union[ - tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator, None -] - -# Version 3 is used since 2021-06-10, compared to version 2 the only change is -# that `bfloat16` arrays are written in Tensorstore using its native `bfloat16` -# support instead of casting them to `uint16`. -VERSION = 3 -# Desired chunk size is 64MiB. -# This is large enough to keep CNS happy but small enough to support a wide -# range of partitionings. -_DESIRED_CHUNK_SIZE_BYTES = 64 * 1024 * 1024 -# TODO(levskaya, adarob): how should we handle stacked/fused variables?? -_TRAIN_DS_PREFIX = checkpoints_utils._TRAIN_DS_PREFIX # pylint: disable=protected-access -_READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec' -_WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec' -_TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}}) - - -def _choose_chunk_shape( - write_shape: Sequence[int], target_elements: int -) -> List[int]: - """Chooses a chunk shape that evenly divides write_shape. - - The chunk shape is chosen such that the total number of elements is less than - or equal to `target_elements`, but is otherwise as large as possible. - - This uses a greedy algorithm that attempts to split the largest dimensions - first. - - Args: - write_shape: Write shape for which to choose a chunk shape. - target_elements: Desired number of elements in chosen chunk shape. Must be - >= 1. - - Returns: - List of length `len(write_shape)` specifying the chosen chunk shape. - """ - assert target_elements >= 1 - rank = len(write_shape) - - # `dim_factors[i]` is the list of divisors of `write_shape[i]` - dim_factors = [ - [i for i in range(1, size + 1) if size % i == 0] for size in write_shape - ] - - # The current chunk shape is: - # [dim_factors[i][-1] for i in range(rank)] - - def get_total_elements(): - """Returns the number of elements in the current chunk shape.""" - total_elements = 1 - for i in range(rank): - total_elements *= dim_factors[i][-1] - return total_elements - - # Reduce the current chunk shape until the desired number of elements is - # reached. - while get_total_elements() > target_elements: - # Greedily reduce the largest dimension. This is not guaranteed to bring us - # the closest to `target_elements`, but is simple to implement and should - # work well enough. - dim_to_reduce = -1 - dim_to_reduce_size = 1 - for i in range(rank): - size = dim_factors[i][-1] - if size > dim_to_reduce_size: - dim_to_reduce_size = size - dim_to_reduce = i - # Can only fail to choose `dim_to_reduce` if all dimensions have size of 1. - # But that cannot happen since `target_elements >= 1`. - assert dim_to_reduce_size > 1 - dim_factors[dim_to_reduce].pop() - return [dim_factors[i][-1] for i in range(rank)] - - -@dataclasses.dataclass -class _ParameterInfo: - """Information needed to read/write and slice a partitioned parameter.""" - - # The unique parameter name. - name: str - # The shape of the parameter. - shape: Tuple[int] - # The TensoreStore Spec containing the minimal information for read/write. - ts_spec: Optional[ts.Spec] - # The LocalChunkInfo for the part of the parameter local to this host. - local_chunk_info: Optional[partitioning.LocalChunkInfo] - # PartitionSpec mesh axes - axes: Optional[partitioning.PartitionSpec] = None - - -def register_ts_spec_for_serialization(): - # Register functions with flax.serialization to handle `ts.Spec`. - def is_dict(s): - return isinstance(s, (dict, flax.core.FrozenDict)) - - serialization.register_serialization_state( - ts.Spec, - ty_to_state_dict=lambda t: t.to_json(), - # The parameter may have been written to tensorstore or msgpack. - # If the former, a dict of the spec will be stored. If the latter it will - # be the value itself. - ty_from_state_dict=lambda t, s: ts.Spec(s) if is_dict(s) else s, - override=True, - ) - - -register_ts_spec_for_serialization() - - -def _run_future_tree(future_tree): - """Block until all futures are resolved on this host.""" - future_leaves, treedef = jax.tree_util.tree_flatten(future_tree) - - async def run(): - return await asyncio.gather(*future_leaves) - - leaves = asyncio.run(run()) - return jax.tree_util.tree_unflatten(treedef, leaves) - - -def get_local_data(x): - """Get local buffer for input data.""" - if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer): - return x.addressable_data(0) - else: - return x - - -def _sync_global_devices(name: str) -> None: - """Sync across all hosts/devices.""" - # Internal mock TPU handling - multihost_utils.sync_global_devices(name) - - -def _cast(target: PyTree, dtype: jnp.dtype): - """Cast arrays in target to dtype.""" - - def maybe_cast(x): - if isinstance(x, (int, str)): - # Ignore common non-array types that shouldn't be cast. - return x - elif x.dtype == dtype: - return x - elif isinstance(x, jax.ShapeDtypeStruct): - return jax.ShapeDtypeStruct(x.shape, dtype) - else: - return x.astype(dtype) - - return jax.tree_util.tree_map(maybe_cast, target) - - -def _update_ts_path_from_relative_to_absolute( - ckpt_dir: str, ts_spec_dict: MutableMapping[str, Any] -): - """Update (in-place) the path and gcs bucket (if applicable) in a TS Spec.""" - - # Handle `gs://` paths. - m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_dir, re.DOTALL) - if m is not None: - if ts_spec_dict['kvstore']['driver'] != 'gcs': - raise ValueError( - 'Incorrect TensorStore Spec. ' - f'Expects kvstore driver to be "gcs" for {ckpt_dir}. ' - f'Got {ts_spec_dict}' - ) - bucket = m.group(1) - ckpt_dir = m.group(2) - ts_spec_dict['kvstore']['bucket'] = bucket - - # Update the path with `ckpt_dir` - - if 'path' in ts_spec_dict['kvstore']: - # tensorstore>=0.1.14 format - ts_spec_dict['kvstore']['path'] = os.path.join( - ckpt_dir, ts_spec_dict['kvstore']['path'] - ) - elif 'path' in ts_spec_dict: - # tensorstore<0.1.14 format - ts_spec_dict['path'] = os.path.join(ckpt_dir, ts_spec_dict['path']) - else: - raise ValueError( - 'Incorrect TensorStore Spec. Expects "path" to be a key of spec or ' - f'`spec["kvstore"]`. Got {ts_spec_dict}' - ) - - -def _maybe_update_ts_from_file_to_gcs(ckpt_contents): - """Updates the TensorStore driver from gfile to gcs.""" - - def _gfile_to_gcs_driver(arr_or_ts_spec_dict): - """Converts the ts.Spec dict using gfile driver to gcs driver.""" - if not isinstance(arr_or_ts_spec_dict, dict): - return arr_or_ts_spec_dict - - if arr_or_ts_spec_dict['kvstore']['driver'] in ('file', 'gfile'): - ts_spec_dict = arr_or_ts_spec_dict - path = ts_spec_dict['kvstore'].pop('path') - # This will be updated to the actual bucket in `_read_ts`. - ts_spec_dict['kvstore'] = { - 'bucket': 't5x-dummy-bucket', - 'driver': 'gcs', - 'path': path, - } - else: - if arr_or_ts_spec_dict['kvstore']['driver'] != 'gcs': - raise ValueError( - 'Unsupported TensoreStore driver. Got ' - f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.' - ) - ts_spec_dict = arr_or_ts_spec_dict - - return ts_spec_dict - - def _is_leaf(value): - return not isinstance(value, dict) or set(value.keys()) >= { - 'driver', - 'kvstore', - 'metadata', - } - - return jax.tree_util.tree_map( - _gfile_to_gcs_driver, ckpt_contents, is_leaf=_is_leaf - ) - - -def _maybe_update_ts_from_gcs_to_file(ckpt_contents): - """Updates the TensorStore driver to gfile or file if different.""" - - # if saved in gcs, change to file - def _gcs_to_file_driver(arr_or_ts_spec_dict): - if not isinstance(arr_or_ts_spec_dict, dict): - return arr_or_ts_spec_dict - - if arr_or_ts_spec_dict['kvstore']['driver'] == 'gcs': - ts_spec_dict = arr_or_ts_spec_dict - path = ts_spec_dict['kvstore'].pop('path') - driver = 'file' - ts_spec_dict['kvstore'] = {'path': path, 'driver': driver} - elif arr_or_ts_spec_dict['kvstore']['driver'] == 'gfile': - ts_spec_dict = arr_or_ts_spec_dict - driver = 'file' - ts_spec_dict['kvstore']['driver'] = driver - elif arr_or_ts_spec_dict['kvstore']['driver'] == 'file': - ts_spec_dict = arr_or_ts_spec_dict - else: - raise ValueError( - 'Unsupported TensoreStore driver. Got ' - f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.' - ) - - return ts_spec_dict - - def _is_leaf(value): - return not isinstance(value, dict) or set(value.keys()) >= { - 'driver', - 'kvstore', - 'metadata', - } - - return jax.tree_util.tree_map( - _gcs_to_file_driver, ckpt_contents, is_leaf=_is_leaf - ) - - -def _get_spec( - directory: str, arr: Any, name: str, metadata: Dict[str, Any] -) -> ts.Spec: - """Get ts.Spec from array and name information.""" - - if os.fspath(directory).startswith('gs://'): - spec = { - 'driver': 'zarr', - 'dtype': jnp.dtype(arr.dtype).name, - 'kvstore': { - 'driver': 'gcs', - # We always write with a dummy bucket and dynamically update the - # bucket information. This makes the checkpoint files portable - # and not bind to the bucket that it was originally written to. - 'bucket': 't5x-dummy-bucket', - }, - 'path': name.replace('/', '.'), - 'metadata': metadata, - } - else: - spec = { - 'driver': 'zarr', - 'dtype': jnp.dtype(arr.dtype).name, - 'kvstore': {'driver': 'file', 'path': name.replace('/', '.')}, - 'metadata': metadata, - } - - return ts.Spec(spec) - - -def _sharding_matches(arr: Any, target_sharding: jax.sharding.Sharding) -> bool: - if not isinstance(arr, jax.Array): - return False - sharding = arr.sharding - return sharding.is_equivalent_to(target_sharding, arr.ndim) - - -def _maybe_make_sharded_array( - arr: Any, - mesh: Optional[jax.sharding.Mesh], - axes: Optional[PartitionSpec] = None, - restore_dtype: Optional[jnp.dtype] = None, - params_on_devices: bool = True, -) -> Any: - """Makes a sharded array from non-sharded array if necessary. - - Args: - arr: array to maybe shard. - mesh: jax.sharding.Mesh. - axes: mesh_axes. - restore_dtype: type to restore as. - params_on_devices: If true, the array will be placed on device. Otherwise, - it will be stored in the host(s) RAM. - - Returns: - Sharded or unsharded array. - """ - if axes is None: - axes = PartitionSpec(None) - assert mesh is not None, 'jax.sharding.Mesh should be provided.' - target_sharding = jax.sharding.NamedSharding(mesh, axes) - if _sharding_matches(arr, target_sharding): - return arr - if isinstance(arr, (np.ndarray, jnp.ndarray)): - if restore_dtype is not None: - arr = arr.astype(restore_dtype) - if not params_on_devices: - return arr - arr = jax.make_array_from_callback( - arr.shape, target_sharding, lambda idx: arr[idx] - ) - return arr - - -class _BytesConditionVariable(object): - """Wraps a condition variable to control concurrency based on bytes.""" - - def __init__(self, num_bytes): - self._max_bytes = num_bytes - self._num_bytes = num_bytes - self._cv = asyncio.Condition(lock=asyncio.Lock()) - - async def wait_for_bytes(self, n_bytes): - async with self._cv: - await self._cv.wait_for(lambda: self._num_bytes > n_bytes) - self._num_bytes -= n_bytes - assert self._num_bytes >= 0 - - async def return_bytes(self, n_bytes): - async with self._cv: - self._num_bytes += n_bytes - assert self._num_bytes <= self._max_bytes - self._cv.notify_all() - - -class SaveStateTransformationFn(typing_extensions.Protocol): - - def __call__( - self, state_dict: PyTree, parameter_infos: PyTree - ) -> Tuple[PyTree, PyTree]: - """Transforms the state and param info, e.g., by remapping parameters. - - Args: - state_dict: State in the current model. - parameter_infos: PyTree containing `_ParameterInfo` objects. - - Returns: - A tuple whose first element is the result of transforming `state_dict` and - whose second element is the result of transforming `parameter_infos`. - """ - - -class RestoreStateTransformationFn(typing_extensions.Protocol): - - def __call__( - self, - state_dict: PyTree, - target_state_dict: PyTree, - *, - is_resuming: bool = False, - ) -> PyTree: - """Transforms the given checkpoint state, e.g., by remapping parameters. - - Args: - state_dict: State to transform, which could be from a previous version of - the model. - target_state_dict: State in the current model. - is_resuming: `True` iff this restore call is due to a job resuming after - being temporarily stopped due to, for example, a preemption. This is - useful when there is restore logic that should run when restoring from - some pre-existing checkpoint, but that should not run again when - resuming from a newly-written checkpoint. - - Returns: - The result of transforming the `state_dict`. - """ - - -class _TfDataCheckpointer: - - def __init__(self, dataset_iterator: tf.data.Iterator): - self._dataset_ckpt = tf.train.Checkpoint(ds=dataset_iterator) - - def save(self, filename: str): - self._dataset_ckpt.write(filename) - - def load(self, filename: str): - self._dataset_ckpt.read(filename).assert_consumed() - - -# TODO(b/216649487): Replace with CheckpointManager. -class Checkpointer(object): - """Handles saving and restoring potentially-sharded T5X checkpoints. - - Checkpoints are stored using a combination of msgpack (via flax.serialization) - and TensorStore. - - Parameters (and other objects) that are not partitioned are written to the - msgpack binary directly (by host 0). Partitioned parameters are each written - to their own TensorStore, with each host writing their portion to the same - TensorStore in parallel. If a partition is written on multiple hosts, the - partition is further sharded across these replicas to avoid additional - overhead. In place of the parameter, a `tensorstore.Spec` is written to the - msgpack (by host 0) as a reference to be used during restore. Note that the - path of the array being written is relative. This makes the checkpoints - portable. In other words, even if the checkpoint files are moved to a new - directory, they can still be loaded. Because the path is relative, the - checkpoint directory information has to be dynamically provided. This is done - by `_update_ts_path_from_relative_to_absolute`. - - For TensorStore driver using Google Cloud Storage (GCS) Key-Value Storage - Layer, the GCS bucket information is necessary. When a checkpoint is written - using the gcs driver, we don't want to hardcode the bucket information in the - resulting file in order to maintain the portability. Therefore, we use a dummy - bucket name of "t5x-dummy-bucket". When reading or writing the checkpoint, the - bucket information is parsed from the checkpoint directory and the bucket - information is dynamically updated. - - Attributes: - checkpoints_dir: a path to a directory to save checkpoints in and restore - them from. - keep: an optional maximum number of checkpoints to keep. If more than this - number of checkpoints exist after a save, the oldest ones will be - automatically deleted to save space. - restore_dtype: optional dtype to cast targets to after restoring. - save_dtype: dtype to cast targets to before saving. - keep_dataset_checkpoints: an optional maximum number of data iterators to - keep. If more than this number of data iterators exist after a save, the - oldest ones will be automatically deleted to save space. - """ - - def __init__( # pytype: disable=annotation-type-mismatch # jnp-type - self, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - checkpoints_dir: epath.PathLike, - dataset_iterator: Optional[ - Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] - ] = None, - *, - keep: Optional[int] = None, - save_dtype: jnp.dtype = np.float32, - restore_dtype: Optional[jnp.dtype] = None, - keep_dataset_checkpoints: Optional[int] = None, - ): - """Checkpointer constructor. - - Args: - train_state: A train state to be used to determine the structure of the - parameter tree, and the *full* (non-partitioned) parameter shapes and - dtypes. Saved and restored train states must match this structure. - partitioner: the partitioner to use for determining the local chunks - mapping or to perform params partitioning on restore. - checkpoints_dir: a path to a directory to save checkpoints in and restore - them from. - dataset_iterator: an optional iterator to save/restore. - keep: an optional maximum number of checkpoints to keep. If more than this - number of checkpoints exist after a save, the oldest ones will be - automatically deleted to save space. - save_dtype: dtype to cast targets to before saving. - restore_dtype: optional dtype to cast targets to after restoring. If None, - no parameter casting is performed. - keep_dataset_checkpoints: an optional maximum number of data iterators to - keep. If more than this number of data iterators exist after a save, the - oldest ones will be automatically deleted to save space. - """ - self._train_state = train_state - self._partitioner = partitioner - self.checkpoints_dir = checkpoints_dir - self.keep = keep - self.keep_dataset_checkpoints = keep_dataset_checkpoints - # Immutable due to use in `_get_parameter_infos` - self._save_dtype = save_dtype - self.restore_dtype = restore_dtype - self._original_dataset_iterator = dataset_iterator - if isinstance(dataset_iterator, tf.data.Iterator): - dataset_iterator = _TfDataCheckpointer(dataset_iterator) - elif isinstance( - dataset_iterator, clu.data.dataset_iterator.TfDatasetIterator - ): - assert dataset_iterator._checkpoint - self._dataset_iterator = dataset_iterator - - data_layout = partitioner.get_data_layout() - self._dataset_ckpt_name = ( - f'{_TRAIN_DS_PREFIX}-' - f'{data_layout.shard_id:03}-of-{data_layout.num_shards:03}' - ) - self._should_write_dataset_ckpt = ( - dataset_iterator and data_layout.is_first_host_in_replica_set - ) - - self._parameter_infos = self._get_parameter_infos() - - asyncio.set_event_loop(asyncio.new_event_loop()) - - def _get_state_dict_for_save( - self, state_dict: Dict[str, Any], lazy_load: bool = True - ) -> MutableMapping[str, Any]: - """Gets the optimizer state dict.""" - - def _lazy_load_device_array(arr): - if isinstance(arr, jax.Array): - if len(arr.sharding.device_set) == 1: - return LazyThreadPoolArray( - arr.shape, arr.dtype, lambda: np.array(arr) - ) - return arr - - if lazy_load: - state_dict = jax.tree_util.tree_map(_lazy_load_device_array, state_dict) - return state_dict - - def _get_parameter_infos(self): - """Generates the state dict of _ParameterInfos for the Optimizer. - - We generate a state dict (matching the shape of the optimizer state dict) - that stores a _ParameterInfo for each parameter array. - - The _ParameterInfo contains the TensorStore spec for the parameter array and - the LocalChunkInfo describing the slice of the array local to this host. - - Returns: - The state dict of _ParameterInfo objects. - """ - - def _get_param_info(name: str, arr: Any, axes: partitioning.PartitionSpec): - # If a node in your model is None it is probably a param_state that is not - # used because of a MultiOptimizer. We don't want to have any parameter - # info for it because it shouldn't be saved or restored. - if arr is None: - return None - # Pass-through empty dict leaves, which occur with optax EmptyState(). - if isinstance(arr, dict) and not arr: - return {} - - if axes is None: - return _ParameterInfo( - name=name, - shape=arr.shape, - ts_spec=None, - local_chunk_info=None, - axes=None, - ) - - if isinstance(arr, jax.Array): - local_chunk_info = None - metadata = array_serialization._get_metadata(arr) # pylint: disable=protected-access - else: - local_chunk_info = self._partitioner.get_local_chunk_info( - arr.shape, axes - ) - write_shape = [ - si if sl == slice(None) else sl.stop - sl.start - for si, sl in zip(arr.shape, local_chunk_info.slice) - ] - # TODO(levskaya, adarob): how should we handle stacked/fused variables?? - chunk_shape = _choose_chunk_shape( - write_shape, - target_elements=_DESIRED_CHUNK_SIZE_BYTES / arr.dtype.itemsize, - ) - - metadata = { - 'compressor': {'id': 'gzip'}, - 'shape': arr.shape, - 'chunks': np.array(chunk_shape), - } - - spec = _get_spec(self.checkpoints_dir, arr, name, metadata) - - return _ParameterInfo( - name, - shape=arr.shape, - ts_spec=spec, - local_chunk_info=local_chunk_info, - axes=axes, - ) - - # Create a tree of param names as the keys on the path to each leaf - # separated by "/". - param_names = traverse_util.unflatten_dict({ - k: '/'.join(k) - for k in traverse_util.flatten_dict( - self._train_state.state_dict(), keep_empty_nodes=True - ) - }) - - return jax.tree_util.tree_map( - _get_param_info, - param_names, - self._get_state_dict_for_save(self._train_state.state_dict()), - self._partitioner.get_mesh_axes(self._train_state).state_dict(), - ) - - def _get_checkpoint_dir(self, step: int) -> epath.PathLike: - return get_checkpoint_dir(self.checkpoints_dir, step) - - def all_steps(self) -> Sequence[int]: - """Returns list of available step numbers in ascending order.""" - return all_steps(self.checkpoints_dir) - - def all_dataset_checkpoint_steps(self) -> Sequence[int]: - """Returns list of available step numbers in ascending order.""" - return all_dataset_checkpoint_steps(self.checkpoints_dir) - - def latest_step(self) -> Optional[int]: - """Returns latest step number or None if no checkpoints exist.""" - return latest_step(self.checkpoints_dir) - - def _remove_old_dataset_checkpoints(self): - """Deletes old dataset checkpoints if there are more than allowed.""" - if self.keep_dataset_checkpoints: - existing_steps = self.all_dataset_checkpoint_steps() - to_remove = len(existing_steps) - self.keep_dataset_checkpoints - if to_remove > 0: - for step in existing_steps[:to_remove]: - checkpoint_utils.remove_dataset_checkpoint( - self._get_checkpoint_dir(step), _TRAIN_DS_PREFIX - ) - - def _remove_old_checkpoints(self): - """Deletes oldest checkpoints if there are more than keep_checkpoints.""" - if not self.keep: - return - existing_steps = self.all_steps() - to_remove = len(existing_steps) - self.keep - if to_remove <= 0: - return - - for step in existing_steps[:to_remove]: - checkpoint_utils.remove_checkpoint_dir(self._get_checkpoint_dir(step)) - - def save( - self, - train_state: train_state_lib.TrainState, - state_transformation_fns: Sequence[SaveStateTransformationFn] = (), - *, - concurrent_gb: int = 128, - ): - """Saves a checkpoint for the given train state. - - Args: - train_state: the train state to save. May contain a combination of - LazyArray objects and arrays (e.g., np.ndarray, jax.DeviceArray) - state_transformation_fns: Transformations to apply, in order, to the state - before writing. - concurrent_gb: the approximate number of gigabytes of partitionable - parameters to process in parallel. Useful to preserve RAM. - """ - start_time = time.time() - step = train_state.step - step = step.get() if isinstance(step, LazyArray) else step - step = get_local_data(step) - # Integer, to avoid side effects in the checkpoint path. - step = int(step) - - # Share a timestamp across devices. - timestamp = multihost_utils.broadcast_one_to_all(np.int32(time.time())) - - final_dir = os.path.join( - self.checkpoints_dir, f'{get_checkpoint_prefix()}_{step}' - ) - tmp_dir = final_dir + f'.tmp-{timestamp}' - - if gfile.exists(final_dir): - logging.info( - 'Skipping save checkpoint for step %d (directory %s already exists)', - step, - final_dir, - ) - return - - logging.info('Saving checkpoint for step %d to %s', step, tmp_dir) - - if jax.process_index() == 0: - gfile.makedirs(tmp_dir) - # Block all hosts until directory is ready. - _sync_global_devices(f'checkpointer:make_dir:{tmp_dir}') - - written_state_dict = self._write_state_to_tensorstore( - tmp_dir, train_state, concurrent_gb, state_transformation_fns - ) - - if self._should_write_dataset_ckpt: - logging.info( - "Writing dataset iterator state to '%s'.", self._dataset_ckpt_name - ) - try: - self._dataset_iterator.save( - os.path.join(tmp_dir, self._dataset_ckpt_name) - ) - except tf.errors.FailedPreconditionError as e: - logging.error( - 'Input pipeline must be stateless in order to checkpoint. Cache ' - 'stateful steps offline or disable iterator checkpointing.' - ) - raise e - - # Block until complete on all hosts. - _sync_global_devices(f'checkpointer:tensorstore_write_complete:{tmp_dir}') - - if jax.process_index() == 0: - written_state_dict = jax.tree_util.tree_map( - get_local_data, written_state_dict - ) - - # Write msgpack file in host 0 only - msgpack_bytes = serialization.to_bytes( - {'version': VERSION, 'optimizer': written_state_dict} - ) - with gfile.GFile(os.path.join(tmp_dir, 'checkpoint'), 'wb') as fp: - fp.write(msgpack_bytes) - - # Finalize checkpoint directory. - if final_dir.startswith('gs://'): - subprocess.run( - ['gsutil', '-m', 'mv', tmp_dir, final_dir], - stdout=subprocess.DEVNULL, - check=True, - ) - else: - gfile.rename(tmp_dir, final_dir) - logging.info('Saved checkpoint for step %d to %s', step, final_dir) - - # Remove old checkpoints, if necessary. - self._remove_old_checkpoints() - self._remove_old_dataset_checkpoints() - - # Block until complete on all hosts. - _sync_global_devices(f'checkpointer:write_complete:{final_dir}') - - end_time = time.time() - monitoring.record_event_duration_secs( - _WRITE_CHECKPOINT_EVENT, end_time - start_time - ) - ocp.utils.record_saved_duration(start_time) - - def _write_state_to_tensorstore( - self, - ckpt_dir: str, - train_state: train_state_lib.TrainState, - concurrent_gb: int, - state_transformation_fns: Sequence[SaveStateTransformationFn], - ) -> Mapping[str, Any]: - """Writes extracted state from train state to Tensorstore.""" - concurrent_bytes = concurrent_gb * 10**9 - - async def _write_array( - maybe_arr: Any, param_info: Optional[_ParameterInfo], cast: bool = False - ): - """Maybe write to TensorStore, returning object to write to msgpack. - - Args: - maybe_arr: array or LazyArray to be written - param_info: ParameterInfo object. If None (or if param_info.ts_spec is - None), the array will be immediately returned without writing to - tensorstore. This is because array is None or is not partitioned, and - should be written separately. - cast: if True, performs cast operation using self._save_dtype. - - Returns: - Tensorstore spec corresponding to the written array. - """ - bytes_cv = _BytesConditionVariable(concurrent_bytes) - - if isinstance(maybe_arr, LazyArray): - maybe_arr = await maybe_arr.get_async() - - if param_info is None or param_info.ts_spec is None: - # Write to the msgpack file on host 0. - return maybe_arr - - arr = maybe_arr - # Wait until memory is available. - if isinstance(arr, jax.Array): - n_bytes = sum([ - shard.data.nbytes - for shard in arr.addressable_shards - if shard.replica_id == 0 - ]) - else: - n_bytes = arr.nbytes - if n_bytes > concurrent_bytes: - logging.warning( - ( - 'Temporarily increasing the concurrency limits from %d bytes to' - ' %d bytes to fit %s.' - ), - concurrent_bytes, - n_bytes, - param_info.name, - ) - n_bytes = concurrent_bytes - await bytes_cv.wait_for_bytes(n_bytes) - - tmp_ts_spec_dict = param_info.ts_spec.to_json() - if cast: - # Set desired destination dtype. - tmp_ts_spec_dict['dtype'] = jnp.dtype(self._save_dtype).name - param_info.ts_spec = ts.Spec(tmp_ts_spec_dict) - # Path and gcs bucket (if applicable) information is updated in-place. - _update_ts_path_from_relative_to_absolute(ckpt_dir, tmp_ts_spec_dict) - if cast: - # Set up casting spec. - tmp_ts_spec_dict = { - 'base': tmp_ts_spec_dict, - 'driver': 'cast', - 'dtype': jnp.dtype(arr.dtype).name, # dtype before cast - } - - if isinstance(arr, jax.Array): - await array_serialization.async_serialize(arr, tmp_ts_spec_dict) - else: - # Array is assumed to be replicated on all hosts in this case. - t = await ts.open( - tmp_ts_spec_dict, - create=True, - open=True, - context=ts.Context({'file_io_concurrency': {'limit': 128}}), - ) - await t.write(arr) - await bytes_cv.return_bytes(n_bytes) - - # N.B. we return the original ts_spec (before - # `_update_ts_path_from_relative_to_absolute` was called). This is because - # we'd like to keep the path as relative, i.e., it doesn't hardcode the - # directory that the checkpoint was originally written. This makes the - # checkpoints portable. - return param_info.ts_spec - - transformed_state_dict, transformed_parameter_infos = ( - _transform_state_and_infos( - train_state.state_dict(), - self._parameter_infos, - state_transformation_fns, - ) - ) - - state_dict_for_save = self._get_state_dict_for_save(transformed_state_dict) - - def _cast_arr_if_not_partitioned(maybe_arr, param_info): - if param_info is None or param_info.ts_spec is None: - return _cast(maybe_arr, self._save_dtype) - return maybe_arr - - state_dict_for_save['target'] = jax.tree_util.tree_map( # pytype: disable=unsupported-operands # dynamic-method-lookup - _cast_arr_if_not_partitioned, - state_dict_for_save['target'], - transformed_parameter_infos['target'], - ) - future_written_state = {} - for k in state_dict_for_save.keys(): - # ensure that only 'target' is cast - future_written_state[k] = jax.tree_util.tree_map( - functools.partial( - _write_array, - cast=(k == 'target' and self._save_dtype is not None), - ), - state_dict_for_save[k], - transformed_parameter_infos[k], - ) - - # Block until complete on this host. - written_state_dict = _run_future_tree(future_written_state) - - # Block until complete on all hosts. - _sync_global_devices(f'checkpointer:ts_write_complete:{ckpt_dir}') - - return written_state_dict - - def _transform_state_and_infos( - self, - state_dict: PyTree, - parameter_infos: PyTree, - state_transformation_fns: Sequence[SaveStateTransformationFn], - ) -> Tuple[PyTree, PyTree]: - """Applies transformations to the state dict and parameter infos PyTrees.""" - return _transform_state_and_infos( - state_dict, parameter_infos, state_transformation_fns - ) - - def restore( - self, - step: Optional[int] = None, - path: Optional[str] = None, - state_transformation_fns: Sequence[RestoreStateTransformationFn] = (), - fallback_state: Optional[Mapping[str, Any]] = None, - lazy_parameters: bool = False, - ) -> train_state_lib.TrainState: - """Restores the host-specific parameters in an Optimizer. - - Either `step` or `path` can be specified, but not both. If neither are - specified, restores from the latest checkpoint in the checkpoints directory. - - Args: - step: the optional step number to restore from. - path: an optional absolute path to a checkpoint file to restore from. - state_transformation_fns: Transformations to apply, in order, to the state - after reading. - fallback_state: a state dict of an optimizer to fall back to for loading - params that do not exist in the checkpoint (after applying all - `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. - The union of `fallback_state` and state loaded from the checkpoint must - match `Checkpointer.optimizer`. - lazy_parameters: whether to load the parameters as LazyArrays to preserve - memory. - - Returns: - The restored train state. - - Raises: - ValueError if both `step` and `path` are specified. - ValueError if checkpoint at `path` or `step` does not exist. - ValueError if `step` and `path` are not specified and no checkpoint is - found in the checkpoints directory. - """ - start_time = time.time() - if lazy_parameters and self._partitioner.params_on_devices: - raise ValueError( - 'Lazy Parameters cannot be copied to devices, please ' - 'set partitioner.params_on_devices=False.' - ) - if step is not None and path is not None: - raise ValueError('At most one of `step` or `path` may be provided.') - if path: - ckpt_path = path - else: - if step is None: - step = self.latest_step() - if not step: - raise ValueError(f'No checkpoints found in {self.checkpoints_dir}.') - ckpt_path = self._get_checkpoint_dir(step) - - if gfile.isdir(ckpt_path): - ckpt_dir = ckpt_path - if gfile.isdir(os.path.join(ckpt_dir, _STATE_KEY)): - ckpt_path = os.path.join(ckpt_path, _STATE_KEY, 'checkpoint') - else: - ckpt_path = os.path.join(ckpt_path, 'checkpoint') - else: - ckpt_dir = os.path.dirname(ckpt_path) - - if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path): - raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}') - - ckpt_type = checkpoint_utils.detect_checkpoint_type( - ckpt_path, expected=checkpoint_utils.CheckpointTypes.T5X - ) - if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF: - raise ValueError( - 'Attempting to restore a TensorFlow checkpoint as a native T5X ' - 'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' - f'{ckpt_path}' - ) - # Don't error out here for Orbax-detected checkpoint (there are edge cases - # where all values are stored in the msgpack and the checkpoint file can be - # loaded by both the Orbax and T5X checkpointer). - logging.info('Restoring from checkpoint: %s', ckpt_path) - - with gfile.GFile(ckpt_path, 'rb') as fp: - # TODO(adarob): Use threaded reading as in flax.checkpoints. - # `ckpt_contents['optimizer']` is a pytree with a realized np.array for - # leaves (params or states) written as msgpack and a ts.Spec (in a dict) - # for leaves written by TensorStore. - ckpt_contents = serialization.msgpack_restore(fp.read()) - - # If reading a ckpt that was written with gfile driver but the current - # session uses the gcs driver, convert the ckpt's driver to gcs. - if os.fspath(ckpt_dir).startswith('gs://'): - ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents) - # If a ckpt was saved in gcs and is being loaded locally, then convert the - # driver to file or gfile. If the ckpt was not saved in gcs, do not change. - else: - ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents) - - ckpt_state_dict = self._get_optimizer_state_dict( - ckpt_contents, - state_transformation_fns, - use_orbax_format=ckpt_type is checkpoint_utils.CheckpointTypes.ORBAX, - ) - - # The state dict may contain TensorStore specs that need to be read. - dummy_spec = ts.Spec({'driver': 'zarr', 'kvstore': {'driver': 'memory'}}) - - # `dummy_written_state_dict` is a pytree with a `dummy_spec` for leaves - # (params or states) written as msgpack and a ts.Spec (in a dict) for leaves - # written by TensorStore. - dummy_written_state_dict = jax.tree_util.tree_map( - lambda x: x.ts_spec or dummy_spec, - self._parameter_infos, - ) - - if fallback_state is None: - restore_parameter_infos = self._parameter_infos - else: - # If `fallback_state` was specified, restore only the subset - # of parameters matched by `self._get_optimizer_state_dict`. The - # rest will be provided by `fallback_state`. - dummy_written_state_dict = state_utils.intersect_state( - dummy_written_state_dict, ckpt_state_dict - ) - restore_parameter_infos = state_utils.intersect_state( - self._parameter_infos, ckpt_state_dict - ) - - restore_parameter_infos_flat = state_utils.flatten_state_dict( - restore_parameter_infos - ) - for key in restore_parameter_infos_flat.keys(): - logging.info('Restoring key from ckpt: %s', key) - - # NB: `serialization.from_state_dict` doesn't check whether the shapes match - # at the leaf level. Non-partitioned leaves (e.g., optimizer states) can - # load arrays with inconsistent shapes. - # `written_state_dict` is a pytree with a realized np.array for leaves - # (params or states) written as msgpack and a `ts.Spec` for leaves written - # by TensorStore. - written_state_dict = serialization.from_state_dict( - dummy_written_state_dict, ckpt_state_dict - ) - state_dict = self._read_state_from_tensorstore( - ckpt_path, - written_state_dict, - restore_parameter_infos=restore_parameter_infos, - lazy_parameters=lazy_parameters, - ) - - # If `fallback_state` was specified, then fill the missing parameters. - if fallback_state is not None: - state_dict = state_utils.merge_state(state_dict, fallback_state) - - for key in state_utils.flatten_state_dict(state_dict).keys(): - if key not in restore_parameter_infos_flat: - logging.info('Not restoring key from ckpt: %s', key) - - if self._dataset_iterator: - logging.info( - "Restoring dataset iterator from '%s'.", self._dataset_ckpt_name - ) - self._dataset_iterator.load( - os.path.join(ckpt_dir, self._dataset_ckpt_name) - ) - - restored_train_state = self._restore_train_state(state_dict) - - end_time = time.time() - monitoring.record_event_duration_secs( - _READ_CHECKPOINT_EVENT, end_time - start_time - ) - return restored_train_state - - def _restore_train_state( - self, state_dict: optimizers.OptimizerStateType - ) -> train_state_lib.TrainState: - """Restores a TrainState from an Optimizer state_dict.""" - return self._train_state.restore_state(state_dict) - - def _create_lazy_awaitable_array( - self, - param_info: _ParameterInfo, - maybe_ts_spec: Any, - ckpt_path: str, - restore_dtype: Optional[jnp.dtype], - ) -> LazyAwaitableArray: - """Creates LazyArray from tensorstore. - - Does not materialize the array immediately. - - Args: - param_info: Information about how to read the parameter, host based sliced - reads and the like. - maybe_ts_spec: The tensorstore spec to read the parameter or some other - object. If this is an array then we will do a host based sliced read on - it (provided the param_info says to). Anything else we just return. - ckpt_path: A base location to use when resolving the relative paths in the - tensorstore spec. - restore_dtype: type to restore as. None indicates that no cast is - requested. - - Returns: - LazyArray object. - """ - mesh = self._partitioner.mesh - axes = param_info.axes - - async def get_fn(): - nonlocal mesh - nonlocal axes - arr = await _read_ts( - param_info, - maybe_ts_spec, - ckpt_path=ckpt_path, - restore_dtype=restore_dtype, - mesh=mesh, - axes=axes, - params_on_devices=self._partitioner.params_on_devices, - ) - return _maybe_make_sharded_array( - arr, - mesh, - axes=axes, - restore_dtype=restore_dtype, - params_on_devices=self._partitioner.params_on_devices, - ) - - return LazyAwaitableArray.from_tensor_store_spec_or_array( - maybe_ts_spec, get_fn, dtype=restore_dtype - ) - - def _read_state_from_tensorstore( - self, - ckpt_path: str, - written_state_dict: Mapping[str, Any], - restore_parameter_infos: Optional[Mapping[str, Any]] = None, - lazy_parameters: bool = False, - ) -> Mapping[str, Any]: - """Sets up lazy reads from Tensorstore and returns them as a state_dict.""" - if restore_parameter_infos is None: - restore_parameter_infos = self._parameter_infos - - # Replace TensorStore Specs with the lazy array values. - state_dict = {} - for k in written_state_dict.keys(): - # ensure that only 'target' is cast - restore_dtype = self.restore_dtype if k == 'target' else None - state_dict[k] = jax.tree_util.tree_map( - functools.partial( - self._create_lazy_awaitable_array, - ckpt_path=ckpt_path, - restore_dtype=restore_dtype, - ), - restore_parameter_infos[k], - written_state_dict[k], - ) - - if not lazy_parameters: - future_state_dict = jax.tree_util.tree_map( - lambda x: x.get_async(), state_dict - ) - state_dict = _run_future_tree(future_state_dict) - - if self.restore_dtype is not None: - if 'target' not in state_dict: - raise ValueError( - f'restore_dtype={self.restore_dtype} was specified, but no `target`' - ' parameters were loaded.' - ) - state_dict['target'] = _cast(state_dict['target'], self.restore_dtype) - - return state_dict - - def restore_from_tf_checkpoint( - self, - path_or_dir: str, - strict: bool = True, - translator: Optional[checkpoint_importer.CheckpointTranslator] = None, - ) -> train_state_lib.TrainState: - """Restore from a TensorFlow-based T5 checkpoint.""" - start_time = time.time() - full_state_dict = checkpoint_importer.restore_from_t5_checkpoint( - self._train_state.state_dict(), - path_or_dir, - lazy_parameters=False, - strict=strict, - translator=translator, - ) - full_state_dict = dict(full_state_dict) - - def _partition_parameter(maybe_arr: Any, param_info: _ParameterInfo): - if isinstance(maybe_arr, np.ndarray) and param_info: - arr = maybe_arr - to_sharded_array = self._partitioner.partition( - lambda x: x, - in_axis_resources=None, - out_axis_resources=param_info.axes, - ) - return to_sharded_array(arr) - return maybe_arr - - if self.restore_dtype is not None: - full_state_dict['target'] = _cast( - full_state_dict['target'], self.restore_dtype - ) - state_dict = jax.tree_util.tree_map( - _partition_parameter, full_state_dict, self._parameter_infos - ) - - restored_train_state = self._restore_train_state(state_dict) - - end_time = time.time() - monitoring.record_event_duration_secs( - _READ_CHECKPOINT_EVENT, end_time - start_time - ) - - return restored_train_state - - def convert_from_tf_checkpoint( - self, - path_or_dir: str, - *, - state_transformation_fns: Sequence[SaveStateTransformationFn] = (), - concurrent_gb: int = 16, - translator: Optional[checkpoint_importer.CheckpointTranslator] = None, - ): - """Convert from a TensorFlow-based T5 checkpoint.""" - full_state_dict = checkpoint_importer.restore_from_t5_checkpoint( - self._train_state.state_dict(), - path_or_dir, - lazy_parameters=True, - translator=translator, - ) - train_state = self._train_state.restore_state(full_state_dict) - self.save( - train_state, - state_transformation_fns=state_transformation_fns, - concurrent_gb=concurrent_gb, - ) - - def _get_optimizer_state_dict( - self, - ckpt_contents: PyTree, - state_transformation_fns: Sequence[RestoreStateTransformationFn], - use_orbax_format: bool = False, - ): - return _get_optimizer_state_dict( - ckpt_contents, - self._train_state.state_dict(), - state_transformation_fns, - use_orbax_format, - ) - - -class CheckpointerConstructor(typing_extensions.Protocol): - """A function that returns a checkpoints.Checkpointer. - - This type annotation allows users to partially bind args to the constructors - of Checkpointer subclasses without triggering type errors. - """ - - def __call__( - self, # pytype: disable=annotation-type-mismatch # jnp-type - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - checkpoints_dir: str, - dataset_iterator: Optional[tf.data.Iterator] = None, - *, - keep: Optional[int] = None, - save_dtype: jnp.dtype = np.float32, - restore_dtype: Optional[jnp.dtype] = None, - keep_dataset_checkpoints: Optional[int] = None, - ) -> Checkpointer: - """Checkpointer constructor. - - Args: - train_state: A train state to be used to determine the structure of the - parameter tree, and the *full* (non-partitioned) parameter shapes and - dtypes. Saved and restored train states must match this structure. - partitioner: the partitioner to use for determining the local chunks - mapping or to perform params partitioning on restore. - checkpoints_dir: a path to a directory to save checkpoints in and restore - them from. - dataset_iterator: an optional iterator to save/restore. - keep: an optional maximum number of checkpoints to keep. If more than this - number of checkpoints exist after a save, the oldest ones will be - automatically deleted to save space. - save_dtype: dtype to cast targets to before saving. - restore_dtype: optional dtype to cast targets to after restoring. If None, - no parameter casting is performed. - keep_dataset_checkpoints: an optional maximum number of data iterators to - keep. If more than this number of data iterators exist after a save, the - oldest ones will be automatically deleted to save space. - """ - pass - - -def populate_metrics_for_steps( - checkpoints_dir: str, metric_name: str, steps: Iterable[int] -) -> Mapping[int, float]: - """Iterate through summary event files and return metrics for `steps`.""" - - metric_run, metric_tag = None, None - - def _try_fill_metric_run_and_tag_names( - metric_name: str, run_keys: Iterable[str] - ) -> bool: - """Extract metric run and tag names by matching one of the `run_keys`. - - This function tries to greedily split user-provided metric_name_to_monitor - into {run} and {tag} components. It does so by trying to match all available - {run}/{tag} names in the provided run_keys. If successful, populates - metric_run and metric_tag. - - Args: - metric_name: metric name to monitor. - run_keys: Set of run keys to test for. - - Returns: - Whether metric name prefix matches one of the run keys, and, as a - side-effect, populates metric_run and metric_tag. - """ - nonlocal metric_run - nonlocal metric_tag - - # Query existing events for different run and tags to match with user - # provided metric name. - m = metric_name.split('/') - possible_run_names = ['/'.join(m[:i]) for i in range(1, len(m))] - for key in run_keys: - for possible_run_name in possible_run_names: - if key == possible_run_name: - metric_run = possible_run_name - metric_tag = metric_name[len(metric_run) + 1 :] - break - - if metric_run and metric_tag: - return True - return False - - metrics_by_step = {} - for subdir in io_wrapper.GetLogdirSubdirectories(checkpoints_dir): - rpath = os.path.relpath(subdir, checkpoints_dir) - # Skip runs that do not match user-specified metric. - if ( - not metric_run - and not _try_fill_metric_run_and_tag_names(metric_name, (rpath,)) - ) or metric_run != rpath: - logging.info('Skipping events in %s', subdir) - continue - - logging.info('Looking for events in %s', subdir) - loader = directory_watcher.DirectoryWatcher( - subdir, - event_file_loader.EventFileLoader, - io_wrapper.IsTensorFlowEventsFile, - ) - for event in loader.Load(): - # Skip metric collection of events for unavailable checkpoints or for - # unmonitored tags. - if ( - event.step not in steps - or not event.summary.value - or event.summary.value[0].tag != metric_tag - ): - continue - metric_value = tf.make_ndarray(event.summary.value[0].tensor) - metrics_by_step[event.step] = metric_value - - return metrics_by_step - - -# TODO(b/216649487): Replace with BestCheckpointManager. -@gin.configurable -class SaveBestCheckpointer(Checkpointer): - """A Checkpointer class that keeps checkpoints based on 'best' metrics. - - This extends the standard Checkpointer to garbage collect checkpoints based on - metric values, instead of step recency. It uses TensorBoard summary files to - determine best values for a given user configured metric name. Events are read - and parsed using TensorBoard's event_processing packages. - - The metric name must be of the form `{run_name}/{tag_name}`. For example, - 'train/accuracy' or 'inference_eval/glue_cola_v002/eval/accuracy'. - - A few important features of this checkpointer: - - - Fallback behavior. It is not possible to verify whether metric names are - valid during initialization, since some metrics may get written out after - some time (e.g., during an evaluation). As such, when user provided metric - names are not found, this checkpointer can be configured for two fall back - strategies: (1) if `keep_checkpoints_without_metrics` is False, we use to - the "most recent checkpoint" strategy from the standard checkpointer, (2) - if `keep_checkpoints_without_metrics` is True, we keep all checkpoints until - metrics become available (potentially indefinitely if summary files have - been deleted or corrupted). - - - The number of checkpoints to keep is always increased by 1. Since its - crucial to always keep the latest checkpoint (for recovery purposes) we - always store the latest checkpoint plus `keep` number of best checkpoints. - - - It is assumed that TensorBoard summaries (event) files share a common root - directory with `checkpoint_dir`, which is the directory passed to the - the logdir crawler that searches for event files. - - Attributes: - checkpoints_dir: a path to a directory to save checkpoints in and restore - them from. - keep: an optional maximum number of checkpoints to keep. If more than this - number of checkpoints exist after a save, the oldest ones will be - automatically deleted to save space. - restore_dtype: optional dtype to cast targets to after restoring. - save_dtype: dtype to cast targets to before saving. - metric_name_to_monitor: Name of metric to monitor. Must be in the format - {run_name}/{tag_name} (e.g., 'train/accuracy', - 'inference_eval/glue_cola_v002/eval/accuracy'). - metric_mode: Mode to use to compare metric values. One of 'max' or 'min'. - keep_checkpoints_without_metrics: Whether to always keep (or delete) - checkpoints for which a metric value has not been found. - force_keep_period: When removing checkpoints, skip those who step is - divisible by force_keep_period (step % force_keep_period == 0). - keep_dataset_checkpoints: an optional maximum number of data iterators to - keep. If more than this number of data iterators exist after a save, the - oldest ones will be automatically deleted to save space. - """ - - def __init__( - self, # pytype: disable=annotation-type-mismatch # jnp-type - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - checkpoints_dir: str, - dataset_iterator: Optional[tf.data.Iterator] = None, - *, - keep: Optional[int] = None, - save_dtype: jnp.dtype = np.float32, - restore_dtype: Optional[jnp.dtype] = None, - metric_name_to_monitor: str = 'train/accuracy', - metric_mode: str = 'max', - keep_checkpoints_without_metrics: bool = True, - force_keep_period: Optional[int] = None, - keep_dataset_checkpoints: Optional[int] = None, - ): - super().__init__( - train_state, - partitioner, - checkpoints_dir, - dataset_iterator, - keep=keep, - save_dtype=save_dtype, - restore_dtype=restore_dtype, - keep_dataset_checkpoints=keep_dataset_checkpoints, - ) - if metric_mode not in ('max', 'min'): - raise ValueError('Unsupported `metric_mode`: %s' % metric_mode) - - self._metric_name_to_monitor = metric_name_to_monitor - self._metric_mode = metric_mode - self._keep_checkpoints_without_metrics = keep_checkpoints_without_metrics - self._force_keep_period = force_keep_period - logging.info( - 'Using SaveBestCheckpointer to keep %s best (%s) metric %s', - keep, - metric_mode, - metric_name_to_monitor, - ) - - def _filter_out_force_keep_period_steps(self, existing_steps): - """Filter out steps that are divisible by keep_period excluding the last.""" - if not existing_steps: - return existing_steps - - # Don't filter out the last step. - last_step = existing_steps.pop() # pytype: disable=attribute-error # dynamic-method-lookup - existing_steps = [ - s for s in existing_steps if s % self._force_keep_period != 0 - ] - return existing_steps + [last_step] - - def _remove_old_checkpoints(self): - """Deletes checkpoints if there are more than keep_checkpoints.""" - if not self.keep: - return - - existing_steps = self.all_steps() - if self._force_keep_period: - # Ignore checkpoints whose step is divisible by the keep period. - existing_steps = self._filter_out_force_keep_period_steps(existing_steps) - - # Artificially add 1 to `keep` since we always keep the latest checkpoint. - if len(existing_steps) <= self.keep + 1: - return - - # Synchronous fetch of new events for existing_steps. - metrics_by_step = populate_metrics_for_steps( - self.checkpoints_dir, self._metric_name_to_monitor, existing_steps - ) - logging.info('SaveBestcheckpointer: collected metrics %s', metrics_by_step) - - # Re-sort existing_steps by metric values while always keeping the latest - # checkpoint. - latest_checkpoint = existing_steps[-1] - existing_steps = existing_steps[:-1] - - if self._keep_checkpoints_without_metrics: - existing_steps = list( - filter(lambda s: s in metrics_by_step, existing_steps) - ) - - to_remove = len(existing_steps) - self.keep - if to_remove <= 0: - return - - # For any remaining steps without metrics, we assign a low/high value which - # will make them candidate for removal. If no metrics are found this sorting - # should preserve current order (oldest first). - not_found_value = float('-inf' if self._metric_mode == 'max' else 'inf') - existing_steps = sorted( - existing_steps, - key=lambda step: metrics_by_step.get(step, not_found_value), - reverse=(self._metric_mode != 'max'), - ) - existing_steps.append(latest_checkpoint) - - for step in existing_steps[:to_remove]: - checkpoint_utils.remove_checkpoint_dir(self._get_checkpoint_dir(step)) - - -def _no_optimizer_state(ckpt_contents: PyTree, use_orbax_format: bool) -> bool: - if use_orbax_format: - return True - try: - version = ckpt_contents.get('version', 0) - return version == 0 - except Exception as e: - raise ValueError('Failed to get version') from e - - -def _should_apply_transform_fns( - ckpt_contents: PyTree, use_orbax_format: bool -) -> bool: - if use_orbax_format: - return True - try: - version = ckpt_contents.get('version', 0) - return version >= 2 - except Exception as e: - raise ValueError('Failed to get version') from e - - -def _get_optimizer_state_dict( - ckpt_contents: PyTree, - optimizer_state: Mapping[str, Any], - state_transformation_fns: Sequence[RestoreStateTransformationFn], - use_orbax_format: bool = False, -): - """Extracts optimizer state dict contents and applies assignment map.""" - if _no_optimizer_state(ckpt_contents, use_orbax_format): - # This is a standard Flax checkpoint and may require remapping below. - ckpt_optimizer_state = ckpt_contents - else: - ckpt_optimizer_state = ckpt_contents['optimizer'] - - if _should_apply_transform_fns(ckpt_contents, use_orbax_format): - for fn in state_transformation_fns: - ckpt_optimizer_state = fn(ckpt_optimizer_state, optimizer_state) - return ckpt_optimizer_state - else: - version = ckpt_contents.get('version', 0) # pylint: disable=unreachable - raise ValueError( - 'Checkpoint versions earlier than 2 are not supported. ' # pylint: disable=unreachable - f'Got version: {version}' - ) - - -def _transform_state_and_infos( - state_dict: PyTree, - parameter_infos: PyTree, - state_transformation_fns: Sequence[SaveStateTransformationFn], -) -> Tuple[PyTree, PyTree]: - """Applies transformations to the state dict and parameter infos PyTrees.""" - for fn in state_transformation_fns: - state_dict, parameter_infos = fn(state_dict, parameter_infos) - return state_dict, parameter_infos - - -async def _read_ts( - param_info: _ParameterInfo, - maybe_tspec: Any, - ckpt_path: str, - restore_dtype: Optional[jnp.dtype] = None, - mesh: Optional[jax.sharding.Mesh] = None, - axes: Optional[PartitionSpec] = None, - params_on_devices: bool = True, -): - """Read from a tensorstore. - - If both `mesh` and `axes` are provided, the method will attempt to restore the - array as a jax.Array. - - Note: - We use param_infos as the first argument because this function is only used - in `jax.tree_util.tree_map` calls. In a tree multimap if the leaf of the - first tree is `None` then is is ignored, even if the second tree has a - subtree at that point. This means that when we are using something like a - MultiOptimizer we can set the parameter info for a variable to `None` and - we can skip processing it, even if the checkpoint has a subtree with things - like optimizer state variables in it. - - Args: - param_info: Information about how to read the parameter, host based sliced - reads and the like. - maybe_tspec: The tensorstore spec to read the parameter or some other - object. If this is an array then we will do a host based sliced read on it - (provided the param_info says to). Anything else we just return. - ckpt_path: A base location to use when resolving the relative paths in the - tensorstore spec. - restore_dtype: type to restore as. None indicates that no cast is requested. - mesh: jax.sharding.Mesh object for GDA restoration. - axes: jax.sharding.MeshAxes object for GDA restoration. - params_on_devices: Whether parameters should be allowed to be deserialized - to devices. - - Returns: - The array. Depending on the value `maybe_tspec` it might be read from - tensorstore, or it might be returned as is. Depending on the values in - param_info (specifically the `local_chunk_info`) it might be the full value - or a specific slice. - """ - # If saved as a numpy array, but a partitioned read is requested, return a - # slice of the array for that host. Otherwise, return the whole thing. - if isinstance(maybe_tspec, np.ndarray) and param_info: - return maybe_tspec - # If we have anything else that isn't a tensorstore spec just return it. - elif not isinstance(maybe_tspec, ts.Spec): - return maybe_tspec - - tmp_ts_spec_dict = maybe_tspec.to_json() - # Remove non-required params so that we can open Tensorstore - # that was created with a different set of params. - del tmp_ts_spec_dict['metadata']['chunks'] - del tmp_ts_spec_dict['metadata']['compressor'] - - # Convert the relative path in the spec to a path based on the checkpoint - # location. Path and gcs bucket (if applicable) information is updated - # in-place. - _update_ts_path_from_relative_to_absolute( - os.path.dirname(ckpt_path), tmp_ts_spec_dict - ) - - if param_info.shape is not None: - ts_spec_arr_shape = tuple(tmp_ts_spec_dict['metadata']['shape']) - # Check that the shapes of the array on disk match the expected shape based - # on the optimizer that is being restored. - if ts_spec_arr_shape != param_info.shape: - raise ValueError( - f'Shape of `{param_info.name}` in checkpoint ' - f'{ts_spec_arr_shape} does not match expected ' - f'{param_info.shape}.' - ) - - if ( - 'dtype' in tmp_ts_spec_dict and tmp_ts_spec_dict['dtype'] == 'uint16' - ) or ( - 'dtype' in tmp_ts_spec_dict['metadata'] - and tmp_ts_spec_dict['metadata']['dtype'] == ' Optional[_ParameterInfo]: - """Create _ParameterInfo that results in a full read.""" - # tspec is only None for `param_states` where the associated variable - # is not updated by any optimizers. By setting the parameter info for - # this to None, we can later short circut processing these subtrees - # during loading. - if maybe_tspec is None: - return None - local_chunk_info = None - tspec = None - if isinstance(maybe_tspec, ts.Spec): - tspec = maybe_tspec - local_chunk_info = partitioning.LocalChunkInfo( - slice=(slice(None, None),), replica_id=0 - ) - return _ParameterInfo( - name='', # We don't ever use the name. - shape=tuple(tspec.to_json()['metadata']['shape']) if tspec else None, - # We just believe the spec in the file. - ts_spec=tspec, - local_chunk_info=local_chunk_info, - axes=None, - ) - - -def find_checkpoint( - path: epath.PathLike, step: Optional[int] = None -) -> epath.PathLike: - """Find the checkpoint file based on paths and steps. - - Args: - path: The location of the checkpoint. Can point to the `model_dir`, the - checkpoint dir with a step, or the actual checkpoint file. - step: The step to load. Only used if you are pointing to the `model_dir` - - Raises: - ValueError if the checkpoint file can't be found. - - Returns: - The path to the checkpoint file. - """ - # If you aren't pointing at the msgpack checkpoint file - if gfile.isdir(path): - # If you didn't specify a step, try to get most recent step - step = latest_step(path) if step is None else step - path = get_checkpoint_dir(path, step) if step is not None else path - # Whether you supplied a step, found a step, or were already pointing at the - # step, you are not pointing at a step directory, so now point to the - # msgpack file. - path = os.path.join(path, 'checkpoint') - # You weren't point to a dir so you were pointing at the msgpack file. - # Check that we found a checkpoint file. - if not gfile.exists(path) or gfile.isdir(path): - raise ValueError(f'Path is not a valid checkpoint: {path}') - return path - - -def load_t5x_checkpoint( - path: str, - step: Optional[int] = None, - state_transformation_fns: Sequence[RestoreStateTransformationFn] = (), - remap: bool = True, - restore_dtype: Optional[jnp.dtype] = None, - lazy_parameters: bool = False, -) -> PyTree: - """Load a T5X checkpoint without pre-defining the optimizer. - - Note: - This only works for T5X checkpoints, not TF checkpoints. - - Args: - path: The location of the checkpoint. - step: The checkpoint from which step should be loaded. - state_transformation_fns: Transformations to apply, in order, to the state - after reading. - remap: Whether to rename the checkpoint variables to the newest version. - restore_dtype: optional dtype to cast targets to after restoring. If None, - no parameter casting is performed. - lazy_parameters: whether to load the parameters as LazyArrays to preserve - memory. - - Returns: - A nested dictionary of weights and parameter states from the checkpoint. - """ - start_time = time.time() - path = find_checkpoint(path, step) - logging.info('Restoring from checkpoint: %s', path) - - # The msgpack file will have all the info we need about the parameter layout. - with gfile.GFile(path, 'rb') as fp: - ckpt_contents = serialization.msgpack_restore(fp.read()) - - # If reading a ckpt that was written with gfile driver but the current - # session uses the gcs driver, convert the ckpt's driver to gcs. - if os.fspath(path).startswith('gs://'): - ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents) - # If a ckpt was saved in gcs and is being loaded locally, then convert the - # driver to file or gfile. If the ckpt was not saved in gcs, do not change. - else: - ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents) - - # Remap that variable names to the most recent formatting. - if remap: - ckpt_optimizer_state = _get_optimizer_state_dict( - ckpt_contents, {}, state_transformation_fns - ) - # If we aren't remapping names we at least need to index into the checkpoint - # file blob to make sure we are only dealing with the optimizer state. - else: - # Grab a subsection of the file depending on the version. - version = ckpt_contents.get('version', 0) - if version == 0: - ckpt_optimizer_state = ckpt_contents - else: - ckpt_optimizer_state = ckpt_contents['optimizer'] - - # Replace all dicts of tensorstore specs with actual `ts.Spec`s. - # When a checkpoint was trained using a MultiOptimizer, some of the parameter - # states may be set to `None` (when a parameter was untouched by any - # optimizer). We still needs references to these in our state so we keep - # empty nodes. - ckpt_optimizer_state_with_specs = state_utils.flatten_state_dict( - ckpt_optimizer_state, keep_empty_nodes=True - ) - ckpt_optimizer_state_with_specs = { - k: ts.Spec(v) if isinstance(v, dict) else v - for k, v in ckpt_optimizer_state_with_specs.items() - } - - # Create fake parameter info that results in reading the whole variable. - param_infos = { - k: fake_param_info(v) for k, v in ckpt_optimizer_state_with_specs.items() - } - - ckpt_optimizer_state_with_specs = traverse_util.unflatten_dict( - ckpt_optimizer_state_with_specs, sep='/' - ) - param_infos = traverse_util.unflatten_dict(param_infos, sep='/') - - def _create_lazy_awaitable_array( - param_info: _ParameterInfo, - maybe_ts_spec: Any, - ckpt_path: str, - restore_dtype: Optional[jnp.dtype], - ) -> LazyAwaitableArray: - get_fn = functools.partial( - _read_ts, - param_info, - maybe_ts_spec, - ckpt_path=ckpt_path, - restore_dtype=restore_dtype, - params_on_devices=False, - ) - return LazyAwaitableArray.from_tensor_store_spec_or_array( - maybe_ts_spec, get_fn, dtype=restore_dtype - ) - - state_dict = jax.tree_util.tree_map( - functools.partial( - _create_lazy_awaitable_array, - ckpt_path=path, - restore_dtype=restore_dtype, - ), - param_infos, - ckpt_optimizer_state_with_specs, - ) - - if not lazy_parameters: - future_state_dict = jax.tree_util.tree_map( - lambda x: x.get_async(), state_dict - ) - state_dict = _run_future_tree(future_state_dict) - - if restore_dtype is not None: - state_dict['target'] = _cast(state_dict['target'], restore_dtype) - - end_time = time.time() - monitoring.record_event_duration_secs( - _READ_CHECKPOINT_EVENT, end_time - start_time - ) - return state_dict - - -_OPTIMIZER_KEY = 'optimizer' -_VERSION_KEY = 'version' -_CHECKPOINTS_SUBDIR = 'checkpoints' -_STATE_KEY = 'state' -_DATASET_KEY = 'dataset' -_METRICS_KEY = 'metrics' -_FLAX_CHECKPOINT_FILE = 'checkpoint' - - -@dataclasses.dataclass -class _OrbaxParamInfo: - name: str - mesh_axes: partitioning.PartitionSpec - - -class DatasetCheckpointHandler(ocp.CheckpointHandler): - """A CheckpointHandler implementation that handles tf.data.Iterator.""" - - def __init__(self, checkpoint_filename: str, should_write_dataset_ckpt: bool): - self._checkpoint_filename = checkpoint_filename - self._should_write_dataset_ckpt = should_write_dataset_ckpt - - def save( - self, - directory: epath.Path, - args: 'DatasetArgs', - ): - """Saves the given item. - - Args: - directory: save location directory. - args: DatasetArgs (see below). - """ - if self._should_write_dataset_ckpt: - item = args.item - if item is None: - raise ValueError('Must provide item to save.') - if jax.process_count() > 1: - directory /= f'process_{jax.process_index()}-of-{jax.process_count()}' - directory.mkdir(parents=False, exist_ok=False) - if isinstance(item, tf.data.Iterator): - ckpt = tf.train.Checkpoint(ds=item) - ckpt.write(os.fspath(directory / self._checkpoint_filename)) - elif isinstance(item, clu.data.dataset_iterator.DatasetIterator): - item.save(os.fspath(directory / self._checkpoint_filename)) - - def restore( - self, - directory: epath.Path, - args: Optional['DatasetArgs'] = None, - ) -> Dataset: - """Restores the given item. - - Args: - directory: restore location directory. - args: DatasetArgs (see below). - - Returns: - a tf.data.Iterator restored from `directory`. - """ - if self._should_write_dataset_ckpt: - if args is None: - raise ValueError('Must provide args to restore.') - item = args.item - if jax.process_count() > 1: - directory /= f'process_{jax.process_index()}-of-{jax.process_count()}' - if isinstance(item, tf.data.Iterator): - ckpt = tf.train.Checkpoint(ds=item) - ckpt.read( - os.fspath(directory / self._checkpoint_filename) - ).assert_consumed() - elif isinstance(item, clu.data.dataset_iterator.DatasetIterator): - item.load(os.fspath(directory / self._checkpoint_filename)) - return item - - -@ocp.args.register_with_handler( - DatasetCheckpointHandler, for_save=True, for_restore=True -) -@dataclasses.dataclass -class DatasetArgs(ocp.args.CheckpointArgs): - item: Optional[Dataset] = None - - -def _step_from_train_state(train_state: train_state_lib.TrainState) -> int: - step = train_state.step - step = step.get() if isinstance(step, LazyArray) else step - step = get_local_data(step) - # Integer, to avoid side effects in the checkpoint path. - return int(step) - - -def _construct_save_args( - param_info: _OrbaxParamInfo, dtype: jnp.dtype -) -> ocp.SaveArgs: - """Create SaveArgs for Orbax saving.""" - if param_info.name.split('.')[0] != 'target': - dtype = None - return ocp.SaveArgs(dtype=dtype) - - -def _construct_restore_args( - param_info: _OrbaxParamInfo, - dtype: jnp.dtype, - mesh: jax.sharding.Mesh, -) -> ocp.RestoreArgs: - """Create RestoreArgs for Orbax restoration.""" - if not isinstance(param_info, _OrbaxParamInfo): # from fallback - return ocp.RestoreArgs(dtype=dtype) - if param_info.name.split('/')[0] != 'target': - dtype = None - if param_info.mesh_axes is None: - return ocp.RestoreArgs(dtype=dtype) - return ocp.ArrayRestoreArgs( - mesh=mesh, - mesh_axes=param_info.mesh_axes, - dtype=dtype, - ) - - -def _construct_orbax_param_infos( - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, -) -> PyTree: - """Construct _OrbaxParamInfo tree for TrainState parameters.""" - param_names = traverse_util.unflatten_dict({ - k: '/'.join(k) - for k in traverse_util.flatten_dict( - train_state.state_dict(), keep_empty_nodes=True - ) - }) - mesh_axes = partitioner.get_mesh_axes(train_state).state_dict() - return jax.tree_util.tree_map(_OrbaxParamInfo, param_names, mesh_axes) - - -def _construct_orbax_restoration_transforms( - state_handler: ocp.PyTreeCheckpointHandler, - step: int, - directory: epath.Path, - state_dict: PyTree, - state_transformation_fns: Sequence[RestoreStateTransformationFn], - restore_args: PyTree, -) -> Tuple[PyTree, Any, PyTree]: - """Construct transformations and restoration arguments for Orbax classes.""" - # After transforms, may be a subset of keys: only the ones we actually need - # to restore. - state_subdir = ocp.utils.get_save_directory( - step, directory, name=_STATE_KEY, step_prefix=get_checkpoint_prefix() - ) - assert state_subdir.is_dir(), state_subdir - use_orbax_format = state_subdir.stem == _STATE_KEY # Standard Orbax format - structure, _ = state_handler._get_internal_metadata( # pylint: disable=protected-access - state_subdir - ) - # Note: Ideally we would use Orbax's `transform_fn` to do this logic, but - # the problem is we need to modify `restore_args`, and there isn't a great - # way to do that within Orbax. - state_dict_to_restore = _get_optimizer_state_dict( - structure, - state_dict, - state_transformation_fns, - use_orbax_format=use_orbax_format, - ) - # After transformations, state_dict_to_restore may still have extra keys - # relative to item (the eventual restoration structure). Extraneous keys - # need to be dropped. - state_dict_to_restore = state_utils.intersect_state( - state_dict_to_restore, state_dict - ) - restore_args = state_utils.intersect_state( - restore_args, state_dict_to_restore - ) - - def _transform_fn( - item_: PyTree, structure_: PyTree, param_infos_: PyTree - ) -> Tuple[PyTree, PyTree]: - # When this function is called from within PyTreeCheckpointHandler, - # transforms will already have been performed (see above), but use this - # function to hack param_infos to return the needed values. - # This structure is unneeded, because we already restored and transformed - # it. - del structure_, param_infos_ - - def _make_orbax_internal_metadata(value: Any, args: ocp.RestoreArgs): - if isinstance(value, ocp.metadata.tree.ValueMetadataEntry): - if value.value_type == 'scalar': - return ocp.metadata.tree.ValueMetadataEntry(value_type='scalar') - if isinstance(args, ocp.ArrayRestoreArgs): - value_type = 'jax.Array' - else: - value_type = 'np.ndarray' - return ocp.metadata.tree.ValueMetadataEntry(value_type=value_type) - else: - return value - - directory_ = ocp.utils.get_save_directory( - step, directory, name=_STATE_KEY, step_prefix=get_checkpoint_prefix() - ) - - def _modify_orbax_param_info(info, value): - if ocp.utils.leaf_is_placeholder(value): - name = ocp.utils.name_from_leaf_placeholder(value) - return dataclasses.replace( - info, name=name, path=directory_ / name, parent_dir=directory_ - ) - return info - - item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args) - param_infos_ = checkpoint_utils.get_restore_parameters(directory_, item_) - param_infos_ = jax.tree.map( - _modify_orbax_param_info, param_infos_, state_dict_to_restore - ) - return item_, param_infos_ - - return state_dict_to_restore, restore_args, _transform_fn - - -def _restore_from_tf_checkpoint( - full_state_dict: PyTree, - param_infos: PyTree, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - restore_dtype: jnp.dtype, -) -> train_state_lib.TrainState: - """Restore from a TensorFlow-based T5 checkpoint.""" - full_state_dict = dict(full_state_dict) - - def _partition_parameter(maybe_arr: Any, param_info: _OrbaxParamInfo): - if isinstance(maybe_arr, np.ndarray) and param_info: - arr = maybe_arr - to_sharded = partitioner.partition( - lambda x: x, - in_axis_resources=None, - out_axis_resources=param_info.mesh_axes, - ) - return to_sharded(arr) - return maybe_arr - - if restore_dtype is not None: - full_state_dict['target'] = _cast(full_state_dict['target'], restore_dtype) - state_dict = jax.tree_util.tree_map( - _partition_parameter, - full_state_dict, - param_infos, - ) - - return train_state.restore_state(state_dict) - - -@gin.configurable -class OrbaxCheckpointManagerInterface: - """Wrapper for ocp.CheckpointManager.""" - - class _CheckpointManagerImpl(ocp.CheckpointManager): - """CheckpointManager implementation to deal with metrics update.""" - - def _get_old_steps_to_remove(self) -> List[int]: - """Update metrics for Orbax management, if available.""" - if self._track_best: - metric_name_to_monitor = self._options.metric_name_to_monitor # pytype: disable=attribute-error - step_to_metric = populate_metrics_for_steps( - os.fspath(self.directory), - metric_name_to_monitor, - self.all_steps(), - ) - for info in self._checkpoints: - if info.step in step_to_metric: - metrics = {metric_name_to_monitor: step_to_metric[info.step]} - info.metrics = metrics - return super()._get_old_steps_to_remove() - - def __init__( - self, - directory: str, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - dataset_iterator: Optional[ - Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] - ] = None, - save_dtype: Optional[jnp.dtype] = None, - restore_dtype: Optional[jnp.dtype] = None, - keep: Optional[int] = None, - period: Optional[int] = 1, - checkpoint_steps: Optional[Sequence[int]] = None, - keep_dataset_checkpoints: Optional[int] = None, - force_keep_period: Optional[int] = None, - metric_name_to_monitor: Optional[str] = None, - metric_mode: str = 'max', - keep_checkpoints_without_metrics: bool = True, - ): - """Performs Orbax setup given standard arguments from T5X.""" - del checkpoint_steps - del keep_dataset_checkpoints - self._train_state = train_state - self._partitioner = partitioner - if isinstance( - dataset_iterator, clu.data.dataset_iterator.TfDatasetIterator - ): - assert dataset_iterator._checkpoint - self._dataset_iterator = dataset_iterator - self._save_dtype = save_dtype - self._restore_dtype = restore_dtype - self._tmp_directory: Optional[epath.PathLike] = None - data_layout = partitioner.get_data_layout() - dataset_ckpt_name = ( - f'{_TRAIN_DS_PREFIX}-' - f'{data_layout.shard_id:03}-of-{data_layout.num_shards:03}' - ) - self._should_write_dataset_ckpt = ( - self._dataset_iterator and data_layout.is_first_host_in_replica_set - ) - self._state_handler = ocp.PyTreeCheckpointHandler(use_ocdbt=True) - item_handlers = { - _STATE_KEY: self._state_handler, - _DATASET_KEY: DatasetCheckpointHandler( - checkpoint_filename=dataset_ckpt_name, - should_write_dataset_ckpt=self._should_write_dataset_ckpt, - ), - } - - def best_fn(metrics): - return metrics[metric_name_to_monitor] - - options = ocp.CheckpointManagerOptions( - max_to_keep=keep, - save_interval_steps=period, - keep_period=force_keep_period, - best_fn=best_fn if metric_name_to_monitor is not None else None, - best_mode=metric_mode, - keep_checkpoints_without_metrics=keep_checkpoints_without_metrics, - cleanup_tmp_directories=True, - step_prefix=get_checkpoint_prefix(), - async_options=ocp.AsyncOptions( - timeout_secs=600, - ), - ) - options.metric_name_to_monitor = metric_name_to_monitor - self._options = options - - if not gfile.isdir(directory): - directory = os.path.dirname(directory) - self._manager = self._CheckpointManagerImpl( - directory=directory, - options=self._options, - item_handlers=item_handlers, - ) - - @property - def directory(self) -> epath.Path: - return self._manager.directory - - def all_steps(self) -> Sequence[int]: - return self._manager.all_steps() - - def latest_step(self) -> Optional[int]: - return self._manager.latest_step() - - def should_save(self, step: int) -> bool: - return self._manager.should_save(step) - - def wait_until_finished(self): - return self._manager.wait_until_finished() - - def close(self): - return self._manager.close() - - def save( - self, - train_state: train_state_lib.TrainState, - state_transformation_fns: Sequence[SaveStateTransformationFn] = (), - force: bool = True, - ) -> bool: - """Saves a checkpoint for the given train state. - - Args: - train_state: the train state to save. May contain a combination of - LazyArray objects and arrays (e.g., np.ndarray, jax.DeviceArray) - state_transformation_fns: Transformations to apply, in order, to the state - before writing. - force: Saves regardless of whether should_save is False. True by default - because should_save logic is handled externally to this class in T5X. - This is because of a feature that decouples actual step and step offset. - - Returns: - Whether the save was performed or not. - """ - start_time = time.time() - step = _step_from_train_state(train_state) - if not force and not self._manager.should_save(step): - return False - - # TODO(b/216649487) Test save-time state_transformation_fns. - state_dict, param_infos = _transform_state_and_infos( - train_state.state_dict(), - _construct_orbax_param_infos(self._train_state, self._partitioner), - state_transformation_fns, - ) - - # Arguments for saving interpretable by Orbax. - save_args = jax.tree_util.tree_map( - functools.partial(_construct_save_args, dtype=self._save_dtype), - param_infos, - ) - - # Separate savable items. - args = { - _STATE_KEY: ocp.args.PyTreeSave( - state_dict, - save_args=save_args, - ), - _DATASET_KEY: DatasetArgs(self._dataset_iterator), - } - args = ocp.args.Composite(**args) - saved = self._manager.save(step, args=args, force=force) - - # Record JAX monitoring events. - end_time = time.time() - monitoring.record_event_duration_secs( - _WRITE_CHECKPOINT_EVENT, end_time - start_time - ) - ocp.utils.record_saved_duration(start_time) - - return saved - - def restore( - self, - step: Optional[int] = None, - path: Optional[str] = None, - fallback_state: Optional[Mapping[str, Any]] = None, - state_transformation_fns: Sequence[RestoreStateTransformationFn] = (), - lazy_parameters: Optional[bool] = False, - ) -> train_state_lib.TrainState: - """Restores a TrainState from the given step or path. - - Note: can only provide one of `step` or `path`. - - Args: - step: the step number to restore from. - path: the full path to restore from. - fallback_state: a state dict of an optimizer to fall back to for loading - params that do not exist in the checkpoint (after applying all - `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. - The union of `fallback_state` and state loaded from the checkpoint must - match `Checkpointer.optimizer`. - state_transformation_fns: Transformations to apply, in order, to the state - after reading. - lazy_parameters: whether to load the parameters as LazyArrays to preserve - memory. - - Returns: - The restored train state. - """ - if lazy_parameters: - logging.warning('Orbax does not support lazy restoration.') - start_time = time.time() - if step is not None and path is not None: - raise ValueError('Can only provide `step` or `path` but not both.') - directory = self.directory - if path is not None: - directory, step = get_step_from_checkpoint_dir(os.fspath(path)) - - # Check for legacy T5X checkpoint: If so, use legacy T5X - # checkpointer to restore the state. The following exclusive features of T5X - # checkpoint are skipped: DatasetIterator, [add more here when discovered] - try: - ckpt_path = find_checkpoint(directory, step) - except ValueError: - # `find_checkpoint` fails if the `.checkpoint` file isn't directly in - # the checkpoint directory. In this case, leave path as None and skip - # the legacy T5X checkpoint check. - ckpt_path = None - - if ckpt_path is not None: - ckpt_type = checkpoint_utils.detect_checkpoint_type( - ckpt_path, expected=checkpoint_utils.CheckpointTypes.ORBAX - ) - if ckpt_type is checkpoint_utils.CheckpointTypes.T5X_TF: - raise ValueError( - 'Attempting to restore a TensorFlow checkpoint as a native T5X ' - 'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' - + ckpt_path - ) - elif ckpt_type is checkpoint_utils.CheckpointTypes.T5X: - legacy_checkpointer = Checkpointer( - self._train_state, - self._partitioner, - self.directory, - restore_dtype=self._restore_dtype, - ) - return legacy_checkpointer.restore( - path=path, - fallback_state=fallback_state, - state_transformation_fns=state_transformation_fns, - ) - - state_dict = self._train_state.state_dict() - # Returns a state dict rather than a train state. - param_infos = _construct_orbax_param_infos( - self._train_state, self._partitioner - ) - # Construct restoration arguments interpretable by Orbax. - restore_args = jax.tree_util.tree_map( - functools.partial( - _construct_restore_args, - dtype=self._restore_dtype, - mesh=self._partitioner.mesh, - ), - param_infos, - ) - # Handle T5X transformation functions, since they are specified differently - # than native Orbax transformation functions. - state_dict_to_restore, restore_args, transform_fn = ( - _construct_orbax_restoration_transforms( - self._state_handler, - step, - directory, - state_dict, - state_transformation_fns, - restore_args, - ) - ) - - # Construct separate items to restore. - args = { - _STATE_KEY: ocp.args.PyTreeRestore( - state_dict_to_restore, - restore_args=restore_args, - legacy_transform_fn=transform_fn, - ), - } - if self._should_write_dataset_ckpt: - args[_DATASET_KEY] = DatasetArgs(self._dataset_iterator) - args = ocp.args.Composite(**args) - restored = self._manager.restore(step, args=args, directory=directory) - state_dict = restored[_STATE_KEY] - if self._should_write_dataset_ckpt: - self._dataset_iterator = restored[_DATASET_KEY] - - # Merge restored state dict with fallback state to fill in any remaining - # params. - if fallback_state is not None: - state_dict = state_utils.merge_state(state_dict, fallback_state) - - # After restoration, some values may still be non-sharded arrays from - # fallback state. - def _maybe_make_sharded_array_helper(arr, info): - if arr is not None: - return _maybe_make_sharded_array( - arr, - self._partitioner.mesh, - axes=info.mesh_axes, - restore_dtype=self._restore_dtype, - ) - - state_dict = jax.tree_util.tree_map( - _maybe_make_sharded_array_helper, - state_dict, - param_infos, - is_leaf=lambda x: x is None, - ) - - train_state = self._train_state.restore_state(state_dict) - - end_time = time.time() - monitoring.record_event_duration_secs( - _READ_CHECKPOINT_EVENT, end_time - start_time - ) - - return train_state - - def restore_from_tf_checkpoint( - self, - path_or_dir: str, - strict: bool = True, - translator: Optional[checkpoint_importer.CheckpointTranslator] = None, - ) -> train_state_lib.TrainState: - """Restore from a TensorFlow-based T5 checkpoint.""" - full_state_dict = checkpoint_importer.restore_from_t5_checkpoint( - self._train_state.state_dict(), - path_or_dir, - strict=strict, - translator=translator, - ) - return _restore_from_tf_checkpoint( - full_state_dict, - _construct_orbax_param_infos(self._train_state, self._partitioner), - self._train_state, - self._partitioner, - self._restore_dtype, - ) - - -@gin.configurable -class CheckpointManagerConstructor(typing_extensions.Protocol): - """A function that returns a checkpoints.CheckpointManager. - - This type annotation allows users to partially bind args to the constructors - of CheckpointManager subclasses without triggering type errors. - """ - - def __call__( - self, - directory: str, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - dataset_iterator: Optional[tf.data.Iterator] = None, - save_dtype: Optional[jnp.dtype] = None, - restore_dtype: Optional[jnp.dtype] = None, - keep: Optional[int] = None, - period: Optional[int] = None, - force_keep_period: Optional[int] = None, - checkpoint_steps: Optional[Sequence[int]] = None, - ) -> OrbaxCheckpointManagerInterface: - """CheckpointManager constructor.""" - pass diff --git a/t5x-main/t5x/checkpoints_utils.py b/t5x-main/t5x/checkpoints_utils.py deleted file mode 100644 index 89163d9601ddd7668938934d7bae3a12c3103d62..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/checkpoints_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for discovering checkpoints on disk. - -This library contains only the relatively small/simple functions needed to -identify checkpoint directories. They can be used on a controller-type job that -doesn't need the ability to actually read the checkpoints and can thus -significantly reduce its binary size by not linking all the Jax libraries. -""" - -import os -import re -from typing import Optional, Sequence, Tuple - -from etils import epath -import gin -from tensorflow.io import gfile - - -_TRAIN_DS_PREFIX = 'train_ds' - - -@gin.configurable -def get_checkpoint_prefix(prefix='checkpoint'): - return prefix - - -def all_steps(checkpoints_dir: str) -> Sequence[int]: - """Returns list of available step numbers in ascending order.""" - glob_pattern = os.path.join(checkpoints_dir, 'checkpoint_*') - checkpoint_paths = gfile.glob(glob_pattern) - re_pattern = re.compile(r'.*/checkpoint_(\d+)$') - matches = [re_pattern.match(ckpt) for ckpt in checkpoint_paths] - return sorted(int(match.group(1)) for match in matches if match) - - -def all_dataset_checkpoint_steps(checkpoints_dir: str) -> Sequence[int]: - """Returns available dataset checkpoint step numbers in ascending order.""" - glob_pattern = os.path.join( - checkpoints_dir, 'checkpoint_*', f'{_TRAIN_DS_PREFIX}-*' - ) - train_ds_paths = gfile.glob(glob_pattern) - re_pattern = re.compile(r'.*/checkpoint_(\d+)/.*$') - matches = [re_pattern.match(path) for path in train_ds_paths] - return sorted(set(int(match.group(1)) for match in matches if match)) - - -def latest_step(checkpoints_dir: str) -> Optional[int]: - """Returns latest step number or None if no checkpoints exist.""" - steps = all_steps(checkpoints_dir) - if not steps: - return None - return steps[-1] - - -def get_checkpoint_dir( - checkpoints_dir: epath.PathLike, - step: int, - step_format_fixed_length: Optional[int] = None, -) -> epath.PathLike: - """Returns path to a checkpoint dir given a parent directory and step.""" - step_str = ( - f'{step:0{step_format_fixed_length}d}' - if step_format_fixed_length is not None - else str(step) - ) - return os.path.join(checkpoints_dir, f'{get_checkpoint_prefix()}_{step_str}') - - -def get_step_from_checkpoint_dir(checkpoints_dir: str) -> Tuple[str, int]: - """Returns a step number and the parent directory.""" - if checkpoints_dir.endswith('/'): - checkpoints_dir = checkpoints_dir[:-1] - parent, checkpoint = os.path.split(checkpoints_dir) - if get_checkpoint_prefix() not in checkpoint: - raise ValueError('Found improperly formatted checkpoint directory.') - return parent, int(checkpoint.replace(f'{get_checkpoint_prefix()}_', '')) diff --git a/t5x-main/t5x/config_utils.py b/t5x-main/t5x/config_utils.py deleted file mode 100644 index 23355bed0001ae791bf6b4db58b91886dceba378..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/config_utils.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for configuring T5X binaries.""" - -import copy -import inspect -from typing import Callable, Optional, TypeVar - -from absl import flags -from absl import logging -from clu import metric_writers -from etils import epath -import fiddle as fdl -from fiddle import absl_flags as fdl_flags -from fiddle import selectors -from fiddle.experimental import serialization -import jax -from t5x import gin_utils -from t5x import utils - - -FLAGS = flags.FLAGS - - -def using_fdl(): - """Returns true if any Fiddle configuration flags are set.""" - return ( - FLAGS.fdl_config is not None - or FLAGS.fdl_config_file is not None - or FLAGS.fdl_help - ) - - -T = TypeVar('T') - - -def prepare_to_summarize(config: fdl.Buildable[T]) -> fdl.Buildable[T]: - """Update `config` so its `summarize_fiddle_config` calls can access it.""" - new_config = copy.deepcopy(config) - # Here because get_current_fiddle_config is a lambda that returns config, - # and not config itself, this does not cause recursive trouble for fiddle. - selectors.select(new_config, direct_summarize_fiddle_config).set( - get_current_fiddle_config=lambda: config - ) - return new_config - - -def sanitize_summary_getter(config: fdl.Buildable[T]) -> fdl.Buildable[T]: - """Update `config` to remove `get_current_fiddle_config` calls.""" - new_config = copy.deepcopy(config) - # Here because get_current_fiddle_config is a lambda that returns config, - # and not config itself, this does not cause recursive trouble for fiddle. - selectors.select(new_config, direct_summarize_fiddle_config).set( - get_current_fiddle_config=None - ) - return new_config - - -def direct_summarize_fiddle_config( - model_dir: str, - summary_writer: Optional[metric_writers.MetricWriter], - step: int, - get_current_fiddle_config: Optional[Callable[[], fdl.Buildable]] = None, -): - """Writes fiddle config to the model dir and TensorBoard summary. - - When passing this function to your fiddle config, don't pass the private - function; instead pass the `fdl.Partial` version of `summarize_fiddle_config`. - - Args: - model_dir: Model directory to write to. - summary_writer: MetricWriter, if any. - step: Current step. - get_current_fiddle_config: This will be filled in by - `t5x.config_utils.prepare_to_summarize()`. - """ - if jax.process_index() != 0: - return - if not get_current_fiddle_config: - raise ValueError( - 'get_current_fiddle_config() not provided. Please pass your fiddle ' - 'config through t5x.config_utils.prepare_to_summarize prior to ' - 'building it.' - ) - config = get_current_fiddle_config() - config = sanitize_summary_getter(config) - config_str = str(config) - - model_dir_path = epath.Path(model_dir) - model_dir_path.mkdir(parents=True, exist_ok=True) - - # Write the config. - (model_dir_path / 'fiddle_config.txt').write_text(config_str) - - # Try to serialize to json as well - try: - config_json = serialization.dump_json(config) - (model_dir_path / 'fiddle_config.json').write_text(config_json) - except serialization.UnserializableValueError as e: - logging.warning( - 'Unable to JSON Serialize fiddle config, skipping. Error: %s', e - ) - - if summary_writer is not None: - summary_writer.write_texts(step, {'fiddle_config': config_str}) - summary_writer.flush() - - -# Pass this when configuring the argument `summarize_config_fn`. -summarize_fiddle_config = fdl.Partial(direct_summarize_fiddle_config) - - -def config_with_fiddle( - function: Callable[..., T], -) -> fdl.Buildable[Callable[..., T]]: - """Configure and build a T5X launcher from Fiddle command line flags. - - The output config, when called via `fdl.build()`, will execute `function`. - - Args: - function: A function that launches a T5X job, e.g., `train`, `eval`, ... - - Returns: - The buildable of the function or object, depending on whether - `--fdl_config_file` or `--fdl_config` was passed. - - Raises: - AssertionError: If `not using_fdl()`. - ValueError: If both fiddle and gin arguments were passed on the command - line. - ValueError: If both `--fdl_config_file` and `--fdl_config` were passed. - ValueError: If the object built via `--fdl_config` does not build as a - call to `function`. - """ - assert using_fdl(), 'No fiddle command line flags found' - if (FLAGS.fdl_config_file or FLAGS.fdl_config) and ( - FLAGS.gin_file or FLAGS.gin_bindings - ): - raise ValueError( - 'Must pass exactly one of `--fdl_config_file`, `--fdl_config`, or ' - '`--gin_file` / `--gin_bindings`. Got: ' - f'--fdl_config_file={FLAGS.fdl_config_file} ' - f'--fdl_config={FLAGS.fdl_config} ' - f'--gin_file={FLAGS.gin_file}.' - f'--gin_bindings={FLAGS.gin_bindings}.' - ) - if FLAGS.fdl_config_file and FLAGS.fdl_config: - raise ValueError( - 'Must pass exactly one of `--fdl_config_file` or `--fdl_config`. Got: ' - f'--fdl_config_file={FLAGS.fdl_config_file} ' - f'--fdl_config={FLAGS.fdl_config}.' - ) - - if FLAGS.fdl_config_file: - # Fill in the launcher function args using a fiddle config json. - config = fdl_flags.create_buildable_from_flags(function) - elif FLAGS.fdl_config: - # Build a launcher object using a fiddle config module+function. - config = fdl_flags.create_buildable_from_flags(function, allow_imports=True) - else: - raise AssertionError('Should not get to this point.') - - # If this is a fdl.Config we want to convert it to - # fdl.Partial so that function() does not execute when - # fdl.build(config) is called. - config = fdl.cast(fdl.Partial, config) - - config_module = inspect.getmodule(config.__fn_or_cls__) - function_module = inspect.getmodule(function) - - # Best effort to ensure that config and function match, even if the json - # defines a different alias to the same module, like __main__ ~= t5x.train. - if (config.__fn_or_cls__.__qualname__ != function.__qualname__) or ( - inspect.getsource(config_module) != inspect.getsource(function_module) - ): - - def module_and_name(fn: Callable[..., T]) -> str: - return '.'.join((fn.__module__, fn.__qualname__)) - - raise ValueError( - 'Expected fiddle flags to configure function ' - f'{module_and_name(function)} but it configured ' - f'{module_and_name(config.__fn_or_cls__)}.\n\nConfig: {config}' - ) - - # Ensure that summarize_fiddle_config calls will work. - config = prepare_to_summarize(config) - - return config - - -def run(main): - """Wrapper for app.run that rewrites jax, gin, and fiddle flags.""" - - def flags_parser(args): - args = gin_utils.rewrite_gin_args(args) - return fdl_flags.flags_parser(args) - - jax.config.parse_flags_with_absl() - if using_fdl(): - utils.run_main(main, flags_parser=flags_parser) - else: - gin_utils.run(main) diff --git a/t5x-main/t5x/configs/__init__.py b/t5x-main/t5x/configs/__init__.py deleted file mode 100644 index a52d4f9529506a53a19a2903bc0796383eb56b78..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This empty file is needed for loading the gin files in this directory.""" diff --git a/t5x-main/t5x/configs/runs/__init__.py b/t5x-main/t5x/configs/runs/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/configs/runs/eval.gin b/t5x-main/t5x/configs/runs/eval.gin deleted file mode 100644 index 278b92e7ca51d4a12785b4befb11d85aea400e2c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/eval.gin +++ /dev/null @@ -1,68 +0,0 @@ -# Defaults for eval.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to evaluate on -# - CHECKPOINT_PATH: The model checkpoint to evaluate -# - EVAL_OUTPUT_DIR: The dir to write results to. -# -# -# Commonly overridden options: -# -# - DatasetConfig.split -# - DatasetConfig.batch_size -# - DatasetConfig.use_cached -# - RestoreCheckpointConfig.mode -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as eval_script -import seqio -from t5x import partitioning -from t5x import utils - - -# Must be overridden -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -EVAL_OUTPUT_DIR = %gin.REQUIRED -TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -eval_script.evaluate: - model = %MODEL # imported from separate gin file - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - output_dir = %EVAL_OUTPUT_DIR - inference_evaluator_cls = @seqio.Evaluator - -partitioning.PjitPartitioner: - num_partitions = 1 - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the dataset. - use_memory_cache = True - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'test' - batch_size = 32 - shuffle = False - seed = 42 - use_cached = False - pack = False - use_custom_packing_ops = False - module = %MIXTURE_OR_TASK_MODULE - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' diff --git a/t5x-main/t5x/configs/runs/export.gin b/t5x-main/t5x/configs/runs/export.gin deleted file mode 100644 index 031cedd1311553c04dc0eb4a1ef2ac6ed5fd70f1..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/export.gin +++ /dev/null @@ -1,103 +0,0 @@ -# Defaults for single_core_export.py. -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features -# to. -# - CHECKPOINT_PATH: The model checkpoint to use for inference -# - MODEL_OUTPUT_DIR: The dir to write results to. -# This can be a dict (recommended) whereby the 'cpu' key specifies where -# the CPU model should be written. -# Alternatively (legacy), this can be a simple directory path as a string. -# It must end in a version number, whereby the CPU model is written in an -# adjacent directory with _cpu appended to the parent directory, with the -# same version directory inside. That is, confusingly, the CPU model is -# not written to the exact directory you specified. -# - MODEL_NAME: Name of model, like "/ml/user/half_plus_two". -# -# Commonly overridden options: -# -# warmup_examples: Optional[List[str]] = None -# jit_compile: bool = False - -from __gin__ import dynamic_registration - -import seqio - -from t5x import checkpoints -from t5x import models -from t5x import partitioning -from t5x import utils -from t5x import export_lib - -# Must be overridden -OUTPUT_FEATURES = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -MODEL_OUTPUT_DIR = %gin.REQUIRED -MODEL_NAME = %gin.REQUIRED -BATCH_SIZE = None -BEAM_SIZE = 1 - -OUTPUT_FEATURES = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()} - -# Plumbing to extract the vocabulary directly from MODEL. This is needed to -# tokenize the features from the saved model inputs we aren't provided with -# vocabularies via a Task. -inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary() -models.get_input_vocabulary.model = %MODEL # imported from separate gin file -outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary() -models.get_output_vocabulary.model = %MODEL # imported from separate gin file - - -# Typical for inference settings: -ACTIVATION_DTYPE = 'bfloat16' - -export_lib.save: - model = %MODEL # imported from separate gin file - inference_mode = 'predict' - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - exportable_module_cls = @export_lib.ExportableModule - create_preprocessor_fn = @export_lib.create_preprocessor - create_inference_function_fn = @export_lib.create_inference_function - create_postprocessor_fn = @export_lib.create_postprocessor - create_polymorphic_shapes_fn = @export_lib.create_batch_polymorphic_shapes - write_warmup_example_fn = @export_lib.write_warmup_examples - partitioner = @partitioning.PjitPartitioner() - output_features = %OUTPUT_FEATURES - task_feature_lengths = %TASK_FEATURE_LENGTHS - output_dir = %MODEL_OUTPUT_DIR - model_name = %MODEL_NAME - batch_size = %BATCH_SIZE - native_lowering = True - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' - dtype = 'bfloat16' - checkpointer_cls = @checkpoints.Checkpointer - -export_lib.create_preprocessor: - output_features = %OUTPUT_FEATURES - task_feature_lengths = %TASK_FEATURE_LENGTHS - -export_lib.create_inference_function: - output_len = None - -export_lib.create_postprocessor: - output_feature_names = None - -export_lib.ExportableModule: - jit_compile = True - use_batch_function = False - -partitioning.PjitPartitioner: - num_partitions = 1 - params_on_devices = True - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -models.EncoderDecoderModel.predict_batch_with_aux: - num_decodes = %BEAM_SIZE - return_all_decodes = True diff --git a/t5x-main/t5x/configs/runs/export_seqio.gin b/t5x-main/t5x/configs/runs/export_seqio.gin deleted file mode 100644 index 0a2f37eb014f90291abbd9b72d755c247d0764a6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/export_seqio.gin +++ /dev/null @@ -1,21 +0,0 @@ -from __gin__ import dynamic_registration - -from t5x import export_lib -from t5x import partitioning - -include 't5x/configs/runs/export.gin' - - -MIXTURE_OR_TASK_NAME = %gin.REQUIRED - -export_lib.save: - create_preprocessor_fn = @export_lib.create_preprocessor_from_task - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - output_features = None - -export_lib.create_preprocessor_from_task: - model = %MODEL - task_feature_lengths = %TASK_FEATURE_LENGTHS - task_name = %MIXTURE_OR_TASK_NAME - serialized_examples = True - run_precache = False diff --git a/t5x-main/t5x/configs/runs/finetune.gin b/t5x-main/t5x/configs/runs/finetune.gin deleted file mode 100644 index bb8e19f3d6bbad5ffbe9b82fd105f21de01a6c28..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/finetune.gin +++ /dev/null @@ -1,155 +0,0 @@ -# Defaults for finetuning with train.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS # includes pretrain steps -# - MODEL_DIR # automatically set when using xm_launch -# - INITIAL_CHECKPOINT_PATH -# -# When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. -# -# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt -# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. -# -# Otherwise, use TRAIN_STEPS_RELATIVE to specify the number of additional -# training steps to perform on top of the initial checkpoint. -# -# Commonly overridden options: -# - DROPOUT_RATE -# - BATCH_SIZE -# - PjitPartitioner.num_partitions -# - Trainer.num_microbatches -# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess -# on the fly. Most common tasks are cached, hence this is set to True by -# default. - -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -# Must be overridden -MODEL_DIR = %gin.REQUIRED -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -MIXTURE_OR_TASK_MODULE = %gin.REQUIRED -TRAIN_STEPS = %gin.REQUIRED -INITIAL_CHECKPOINT_PATH = %gin.REQUIRED - -# Commonly overridden -DROPOUT_RATE = 0.1 -USE_CACHED_TASKS = True -BATCH_SIZE = 128 - -# Sometimes overridden -EVAL_STEPS = 20 -EVAL_PERIOD = 1000 - -# Convenience overrides. -EVALUATOR_USE_MEMORY_CACHE = True -EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. -JSON_WRITE_N_RESULTS = None # Write all inferences. -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None -TRAIN_STEPS_RELATIVE = None - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -train_script.train: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - checkpoint_cfg = @utils.CheckpointConfig() - partitioner = @partitioning.PjitPartitioner() - trainer_cls = @trainer.Trainer - total_steps = %TRAIN_STEPS - eval_steps = %EVAL_STEPS - eval_period = %EVAL_PERIOD - relative_steps = %TRAIN_STEPS_RELATIVE - random_seed = %RANDOM_SEED - use_hardware_rng = %USE_HARDWARE_RNG - summarize_config_fn = @gin_utils.summarize_gin_config - inference_evaluator_cls = @seqio.Evaluator - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = %EVALUATOR_NUM_EXAMPLES - use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE - -seqio.JSONLogger: - write_n_results = %JSON_WRITE_N_RESULTS - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = True - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -train_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # compute max - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = False - module = %MIXTURE_OR_TASK_MODULE - -utils.CheckpointConfig: - restore = @utils.RestoreCheckpointConfig() - save = @utils.SaveCheckpointConfig() -utils.RestoreCheckpointConfig: - path = %INITIAL_CHECKPOINT_PATH - mode = 'specific' - dtype = 'float32' -utils.SaveCheckpointConfig: - period = 5000 - dtype = 'float32' - keep = None # keep all checkpoints - save_dataset = False # don't checkpoint dataset state - -trainer.Trainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() -utils.create_learning_rate_scheduler: - factors = 'constant' - base_learning_rate = 0.001 - warmup_steps = 1000 diff --git a/t5x-main/t5x/configs/runs/infer.gin b/t5x-main/t5x/configs/runs/infer.gin deleted file mode 100644 index 0918d2f4843d698cf27787c62f7e09cf81c1e835..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/infer.gin +++ /dev/null @@ -1,71 +0,0 @@ -# Defaults for infer.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference -# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features -# to. -# - CHECKPOINT_PATH: The model checkpoint to use for inference -# - INFER_OUTPUT_DIR: The dir to write results to. -# -# -# Commonly overridden options: -# -# - infer.mode -# - infer.checkpoint_period -# - infer.shard_id -# - infer.num_shards -# - DatasetConfig.split -# - DatasetConfig.batch_size -# - DatasetConfig.use_cached -# - RestoreCheckpointConfig.is_tensorflow -# - RestoreCheckpointConfig.mode -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as infer_script -from t5x import partitioning -from t5x import utils - -# Must be overridden -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -INFER_OUTPUT_DIR = %gin.REQUIRED - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -infer_script.infer: - mode = 'predict' - model = %MODEL # imported from separate gin file - output_dir = %INFER_OUTPUT_DIR - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - checkpoint_period = 100 - shard_id = 0 - num_shards = 1 - -partitioning.PjitPartitioner: - num_partitions = 1 - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - module = %MIXTURE_OR_TASK_MODULE - task_feature_lengths = %TASK_FEATURE_LENGTHS - use_cached = False - split = 'test' - batch_size = 32 - shuffle = False - seed = 0 - pack = False - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' - dtype = 'bfloat16' diff --git a/t5x-main/t5x/configs/runs/infer_from_tfexample_file.gin b/t5x-main/t5x/configs/runs/infer_from_tfexample_file.gin deleted file mode 100644 index 5d62b27555ecfef3cd801098fe640ac09eff744c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/infer_from_tfexample_file.gin +++ /dev/null @@ -1,90 +0,0 @@ -# Defaults for infer.py if using a TFExample file as input. -# -# -# The features from each TFExample are tokenized using the model's vocabulary. -# By default, the inputs feature is assumed to be keyed as 'inputs', but this -# can be overridden with `create_task_from_tfexample_file.inputs_key`. -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from. -# - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently -# supported: 'tfrecord', 'recordio', 'sstable'. -# - FEATURE_LENGTHS: The maximum length per feature in the TF Examples. -# - CHECKPOINT_PATH: The model checkpoint to use for inference -# - INFER_OUTPUT_DIR: The dir to write results to. -# -# -# Commonly overridden options: -# -# - infer.mode -# - infer.checkpoint_period -# - infer.shard_id -# - infer.num_shards -# - create_task_from_tfexample_file.inputs_key -# - create_task_from_tfexample_file.targets_key -# - DatasetConfig.split -# - DatasetConfig.batch_size -# - RestoreCheckpointConfig.mode -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as infer_script -import seqio -from t5x import models -from t5x import partitioning -from t5x import utils - -# Must be overridden -TF_EXAMPLE_FILE_PATHS = %gin.REQUIRED -TF_EXAMPLE_FILE_TYPE = %gin.REQUIRED -FEATURE_LENGTHS = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -INFER_OUTPUT_DIR = %gin.REQUIRED - -infer_script.infer: - mode = 'predict' - model = %MODEL # imported from separate gin file - output_dir = %INFER_OUTPUT_DIR - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - checkpoint_period = 100 - shard_id = 0 - num_shards = 1 - -partitioning.PjitPartitioner: - num_partitions = 1 - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -utils.DatasetConfig: - mixture_or_task_name = @infer_script.create_task_from_tfexample_file() - task_feature_lengths = %FEATURE_LENGTHS - split = 'infer' - batch_size = 32 - shuffle = False - seed = 0 - pack = False - -infer_script.create_task_from_tfexample_file: - paths = %TF_EXAMPLE_FILE_PATHS - file_type = %TF_EXAMPLE_FILE_TYPE - inputs_key = 'inputs' - targets_key = None - features = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()} - -# Plumbing to extract the vocabulary directly from MODEL. This is needed to -# tokenize the features from the TFExample we aren't provided with vocabularies -# via a Task. -inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary() -models.get_input_vocabulary.model = %MODEL -outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary() -models.get_output_vocabulary.model = %MODEL - -utils.RestoreCheckpointConfig: - mode = 'specific' - path = %CHECKPOINT_PATH - dtype = 'bfloat16' - diff --git a/t5x-main/t5x/configs/runs/precompile.gin b/t5x-main/t5x/configs/runs/precompile.gin deleted file mode 100644 index 787d7d9a0f0a107bcd59ba9e5d83442fd042182b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/precompile.gin +++ /dev/null @@ -1,59 +0,0 @@ -# Defaults for precompile mode in main.py. -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS -# - MODEL_DIR: # automatically set when using xm_launch -# -# Commonly overridden options: -# -# - USE_CACHED_TASKS -# - BATCH_SIZE -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -MODEL_DIR = %gin.REQUIRED -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED - - -# Commonly overridden -USE_CACHED_TASKS = True -BATCH_SIZE = 128 - -# None always uses faster, hardware RNG -RANDOM_SEED = None - -train_script.precompile: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - random_seed = %RANDOM_SEED - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - backend = "tpu" - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = True - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True diff --git a/t5x-main/t5x/configs/runs/pretrain.gin b/t5x-main/t5x/configs/runs/pretrain.gin deleted file mode 100644 index 3161d117b2e88b2a169496fac4a4c7934953014b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/configs/runs/pretrain.gin +++ /dev/null @@ -1,110 +0,0 @@ -# Defaults for pretraining with train.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS -# - MODEL_DIR: # automatically set when using xm_launch -# -# Commonly overridden options: -# -# - train/DatasetConfig.batch_size -# - train_eval/DatasetConfig.batch_size -# - PjitPartitioner.num_partitions -# - Trainer.num_microbatches -# - DROPOUT_RATE -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -TRAIN_STEPS = %gin.REQUIRED -MODEL_DIR = %gin.REQUIRED -BATCH_SIZE = 128 -USE_CACHED_TASKS = True - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None -SHUFFLE_TRAIN_EXAMPLES = True - -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None -TRAIN_STEPS_RELATIVE = None - -# Can be overridden with `train.*`.` -train_script.train: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = None - checkpoint_cfg = @utils.CheckpointConfig() - partitioner = @partitioning.PjitPartitioner() - trainer_cls = @trainer.Trainer - total_steps = %TRAIN_STEPS - eval_steps = 20 - eval_period = 1000 - relative_steps = %TRAIN_STEPS_RELATIVE - random_seed = %RANDOM_SEED - use_hardware_rng = %USE_HARDWARE_RNG - summarize_config_fn = @gin_utils.summarize_gin_config - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = %SHUFFLE_TRAIN_EXAMPLES - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -train_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -utils.CheckpointConfig: - restore = @utils.RestoreCheckpointConfig() - save = @utils.SaveCheckpointConfig() -utils.RestoreCheckpointConfig: - path = [] # initialize from scratch -utils.SaveCheckpointConfig: - period = 1000 - dtype = 'float32' - keep = None # keep all checkpoints - save_dataset = False # don't checkpoint dataset state - -trainer.Trainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() - -utils.create_learning_rate_scheduler: - factors = 'constant * rsqrt_decay' - base_learning_rate = 1.0 - warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. diff --git a/t5x-main/t5x/contrib/__init__.py b/t5x-main/t5x/contrib/__init__.py deleted file mode 100644 index 089ca7ab19058f576564841f21aa3588f519b996..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This empty file is needed for packaging the contrib modules.""" diff --git a/t5x-main/t5x/contrib/calm/README.md b/t5x-main/t5x/contrib/calm/README.md deleted file mode 100644 index c2c053deb6fdec41d92860f8b156132367aa4d5f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/calm/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Confident Adaptive Language Modeling (CALM) - -This repository contains overrides and configs for running the CALM T5 model in T5X, introduced in the NeurIPS 2022 paper: [Confident Adaptive Language Modeling](https://arxiv.org/abs/2207.07061). - -CALM skips Transformer decoder layers when generating text by early exiting based on calibrated confidence measures. - -This model should be paired with the Flaxformer [calm_t5](https://github.com/google/flaxformer/tree/main/flaxformer/architectures/calm_t5) architecture. - -## Reference -When referring to this model, please cite this paper: - -``` -@inproceedings{Schuster2022CALM, - title={Confident Adaptive Language Modeling}, - author={Tal Schuster and Adam Fisch and Jai Gupta and Mostafa Dehghani and Dara Bahri and Vinh Quang Tran and Yi Tay and Donald Metzler}, - booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, - url = {https://arxiv.org/abs/2207.07061}, - year={2022}, -} -``` - - -## Note -This is not an officially supported Google product. \ No newline at end of file diff --git a/t5x-main/t5x/contrib/calm/__init__.py b/t5x-main/t5x/contrib/calm/__init__.py deleted file mode 100644 index 3283cf1d013c02acbf1cb2b253e51ec54ac7b796..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/calm/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Import API modules.""" - -import t5x.contrib.calm.decoding -import t5x.contrib.calm.models -# Version number. -from t5x.version import __version__ diff --git a/t5x-main/t5x/contrib/calm/decoding.py b/t5x-main/t5x/contrib/calm/decoding.py deleted file mode 100644 index 93376c1bc9367e110bee5eab42172c3fa20e5ab4..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/calm/decoding.py +++ /dev/null @@ -1,683 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fast decoding routines that log stats for early exiting.""" - -import functools -from typing import Callable, Mapping, Optional, Tuple, Union - -import flax -import jax -from jax import lax -from jax import random -import jax.numpy as jnp -from t5x import binary_search -from t5x import decoding -from t5x.decoding import _is_tracer -from t5x.decoding import DecodingState -from t5x.decoding import MIN_TEMPERATURE -from t5x.decoding import NEG_INF - - -# ------------------------------------------------------------------------------ -# Temperature Sampling -# ------------------------------------------------------------------------------ - - -@flax.struct.dataclass -class SamplingLoopState: - """Holds sampling state data. - - Attributes: - cur_index: [batch_size] array position of the sampling loop in the length - dimension. - sequences: [batch_size * num_decodes, max_decode_len] array of current - sampled sequence prefixes. - cache: any mapping of arrays, e.g. flax attention cache. - cur_token: [batch_size, num_decodes] single timestep slice containing - current tokens. - ended: [batch_size, num_decodes] binary array marking completed sequences. - rng: Jax PRNGKey - log_prob: [batch_size, num_decodes] array of log probs for each sequence. - confidences: [batch_size, max_decode_len] array of confidence scores per - token measured at the last computed decoder layer. - exits: [batch_size, max_decode_len] array recording the number of decoder - layers used (until exiting) per token. - """ - - cur_index: jnp.ndarray - sequences: jnp.ndarray - cache: Mapping[str, jnp.ndarray] - cur_token: jnp.ndarray - ended: jnp.ndarray - rng: jnp.ndarray - log_prob: jnp.ndarray - confidences: jnp.ndarray - exits: jnp.ndarray - - -def temperature_sample( - inputs: jnp.ndarray, - cache: Mapping[str, jnp.ndarray], - tokens_to_logits: Callable[ - [DecodingState], - Tuple[jnp.ndarray, Mapping[str, jnp.ndarray], jnp.ndarray, jnp.ndarray], - ], - eos_id: int, - decode_rng: Optional[jnp.ndarray] = None, - num_decodes: int = 1, - temperature: Union[float, jnp.ndarray] = 1.0, - topk: int = 1, - topp: float = 0.0, - cache_offset: int = 0, - initial_index: Optional[jnp.ndarray] = None, - max_decode_steps: Optional[Union[int, jnp.ndarray]] = None, - max_decode_steps_hard_limit: Optional[int] = None, - rescale_log_probs: bool = True, - state_callback_fn: Optional[ - Callable[[SamplingLoopState], SamplingLoopState] - ] = None, - logit_callback_fn: Optional[ - Callable[[jnp.ndarray, SamplingLoopState], jnp.ndarray] - ] = None, -) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]: - """Temperature sampling for language model generation. - - The temperature sampling is performed `num_decodes` times in a vectorized - manner by expanding the batch dimension. This is similar to how beam search - expands the batch dimension to process each batch element with multiple beams. - - This function dynamically updates the `inputs` array by sampling from the - model logits, which is provided by `tokens_to_logits` callable. The input - sequences are expanded at the end, populated and sliced by dropping the first - position. - - If `inputs` has non-zero entries, those values are not modified, i.e., - the sampled values for those positions are discarded. This simulates the - teacher forcing on the prefix positions. - - There are a few important observations related to this function. - - 1. The `inputs` is assumed to be a non-packed sequence. - - 2. If `initial_index=None`, then `inputs`[:, 0] is ignored. We will use 0 as a - BOS token to start the generation. This inherently assumes that `inputs` is - already shifted to the right by one position. If `initial_index=an_array`, - the token values at `inputs`[:, initial_index] are used as the token to - start the generation. - - 3. The loop index, i, is a vector of shape [batch_size]. When beginning - generation from scratch, each value will always have the same value. When - beginning with a partially filled cache, the loop index of different - elements can differ, via providing a value for `initial_index`. - - 3. Unless all batch elements generated the eos_id before reaching the end, we - always make `max_decode_len = inputs.shape[1]` number of calls to - `tokens_to_logits` when decoding from scratch and - `max_decode_len - jnp.minimum(initial_index)` number of calls when starting - from a partially filled cache. - - 4. Let `output` be the output sequences, i.e.,`sequences`[:, 1:]. Then - `output`[:, j] are the tokens generated when the while loop counter `i = - j`. Therefore, we generate the last token when `i = max_decode_len - 1` - and exit the while loop as all `i`s are incremented to `max_decode_len`. - - 5. Once `eos_id = 1` is generated, the subsequent predictions are all replaced - by padding token 0. - - 6. When using a partially filled cache, different batch elements can have - different lengths. This means an input that has a longer input will have - fewer steps until its `i` value reaches `max_decode_len` than an input with - a shorter input. We keep these longer examples alive, doing busy work - continually overwriting a new garbage token at the end of the sequence - until shorter examples finish. - - 7. When using a partially filled cache, providing a value for `initial_index`, - the attention cache index should be a vector of [batch_size]. - - We show three examples to illustrate how this function works. In addition to - input and output of the function, we also show two intermediate values: - `expanded_prompt_inputs` and `final_sequences`. Also for simplicity, the - examples are limited to `num_decodes = 1` usage and the `num_decodes` - dimension is omitted. - - ``` - Example 1: - inputs = [0, 5, 6, 1, 0] - expanded_prompt_inputs = [0, 5, 6, 1, 0, 0] - final_sequences = [0, 5, 6, 1, a, b] # before slicing. - output = [5, 6, 1, a, b] - where `a` is prediction while taking 1 as input and `b` is prediction while - taking `a` as input. - - Example 2 (early stopping): - inputs = [[0, 5, 1, 0, 0, 0, 0], - [0, 8, 0, 0, 0, 0, 0] - expanded_prompt_inputs = [[0, 5, 1, 0, 0, 0, 0, 0], - [0, 8, 0, 0, 0, 0, 0, 0] - final_sequences = [[0, 5, 1, a, b, c=1, 0, 0], - [0, 8, d, e, f=1, g=0, 0, 0]] - output = [[5, 1, a, b, c=1, 0, 0], - [8, d, e, f=1, g=0, 0, 0]] - - In this example, there are two sequences. Let's look at sequence 0. The - first generated token is `a`, which is in turn used to generate `b`. - Finally, `c = 1` is generated with the input `b`. Then the loop terminates - early because 1 is the `eos_id`. - - Now consider sequence 1. The when `f = 1` was generated, it is considered - done. Since sequence 0 is not done at this point, the next prediction, i.e., - `g` is zerod out. This continues until the end. - - Example 3 (prefilled cache): - inputs = [[0, 5, 2, 6, 1, 0], - [0, 8, 1, 0, 0, 0]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, 0, 0, 0], - [0, 8, 1, 0, 0, 0, 0, 0]] - max_decode_length = 6 - i = [4, 2] - input_tokens = [[1], - [1]] - output_tokens = [[a], - [b]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, 0, 0], - [0, 8, 1, b, 0, 0, 0, 0]] - i = [5, 3] - input_tokens = [[a], - [b]] - output_tokens = [[c], - [d]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, 0], - [0, 8, 1, b, d, 0, 0, 0]] - i = [6, 4] - input_tokens = [[c], - [d]] - output_tokens = [[y], - [e]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, y], - [0, 8, 1, b, d, e, 0, 0]] - i = [6, 5] - input_tokens = [[z], - [e]] - output_tokens = [[z], - [f]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, z], - [0, 8, 1, b, d, e, f, 0]] - i = [6, 6] - exit - outputs = [[5, 2, 6, 1, a, c], - [8, 1, b, d, e, f]] - - In this example, there are two sequences with different input lengths. Thus - the two caches had been filled to different positions. As we decode, the - first sequence hits the max decode length before the second. In order to - avoid prematurely ending decoding for the second sequence, the first - sequence continually overwrites the final token. - - Example 4 (prefilled cache and max decode steps): - inputs = [[0, 2, 0, 0, 0, 0, 0, 0], - [0, 3, 4, 0, 0, 0, 0, 0]] - expanded_prompt_inputs = [[0, 2, 0, 0, 0, 0, 0, 0, 0, 0] - [0, 3, 4, 0, 0, 0, 0, 0, 0, 0]] - initial_indices = [1, 2] - max_decode_step = 2 - - Then `max_decode_len = [3, 4]`. - i = [1, 2] - input_tokens = [[2], - [4]] - output_tokens = [[a], - [b]] - expanded_prompt_inputs = [[0, 2, a, 0, 0, 0, 0, 0, 0, 0] - [0, 3, 4, b, 0, 0, 0, 0, 0, 0]] - i = [2, 3]] - input_tokens = [[a], - [b]] - output_tokens = [[c], - [d]] - expanded_prompt_inputs = [[0, 2, a, c, 0, 0, 0, 0, 0, 0] - [0, 3, 4, b, d, 0, 0, 0, 0, 0]] - This is the last while loop iteration with i == max_decode_len - 1. - outputs = [[2, a, c, 0, 0, 0, 0, 0] - [3, 4, b, d, 0, 0, 0, 0]] - ``` - - Args: - inputs: array: [batch_size, max_decode_len] int32 sequence of tokens. - cache: flax attention cache. - tokens_to_logits: fast autoregressive decoder function taking single token - slices and cache and returning next-token logits and updated cache. - eos_id: int: end-of-sentence token for target vocabulary. - decode_rng: JAX PRNGKey. - num_decodes: number of decoded sequences to be returned. - temperature: float: sampling temperature factor. As it approaches zero this - becomes equivalent to greedy sampling. - topk: integer: if nonzero only use the top-k logits to sample next token, if - zero don't use any cutoff and sample from full logits over vocabulary. - topp: float: if nonzero only use the smallest number of logits whose - cumulative sum of probs adds up to (at least) topp. Will raise ValueError - if it's nonzero when topk is nonzero. - cache_offset: axis offset for cache, arising from scanned layers. - initial_index: Optional[array]: [batch_size] int32 a vector of loop indexes - to start decoding at. - max_decode_steps: int: an optional maximum number of decoding steps. If - None, it will decode until the full input shape `inputs.shape[1]` is - filled. max_decode_steps begins counting after the prompt, so it will - decode at most len(prompt) + max_decode_steps tokens. - max_decode_steps_hard_limit: int: an optional fixed hard limit on - max_decode_steps. If this is set (not None and > 0), and max_decode_steps - is also set, then max_decode_steps will be clipped to this limit. The - value max_decode_steps can be an ndarray, but max_decode_steps_hard_limit - must be a Python integer or None. - rescale_log_probs: bool: whether to apply temperature, topp, and topk - rescaling to the log probs which are returned. If True, the log_probs will - include these transformations (for example, with topk=1, all log_probs - will be identically 0.0). If False, the log_probs will not be affected, - and topk/topp/temperature will not affect sequence probabilities. - state_callback_fn: Function that modifies the sampling loop state before - each step. This can be used to manipulate any part of the state either on - the accelerator or on the host using host callback. The function should - take a SamplingLoopState as argument, and it returns the updated state. - See `decoding_test.py` for an example usage. - logit_callback_fn: Function that modifies the logits before each temperature - sampling step. The function should take arguments (logits, state) and it - should return the modified logits. See `decoding_test.py` for an example - usage. - - Returns: - A tuple (decodes, log_prob) where `decodes` is sampled sequences with shape - [batch_size, num_decodes, max_decode_len] sorted by `log_prob`, which is log - probability of each of the sampled sequences. - """ - if decode_rng is None: - decode_rng = jax.random.PRNGKey(0) - - if ( - max_decode_steps_hard_limit is not None - and max_decode_steps_hard_limit > 0 - and max_decode_steps is not None - ): - max_decode_steps = jnp.minimum( - max_decode_steps, max_decode_steps_hard_limit - ) - - if num_decodes > 1: - # [batch, len] -> [batch * num_decodes, len] - expanded_inputs = decoding.flat_batch_beam_expand(inputs, num_decodes) - expanded_cache = decoding.cache_map( - functools.partial( - decoding.flat_batch_beam_expand, - beam_size=num_decodes, - offset=cache_offset, - ), - cache, - # When we start with a prefilled cache, the cache index is no longer a - # scalar that will broadcast across multiple decodes, it is a vector and - # needs to be updated to handle the multiple decodes. - apply_to_index=initial_index is not None, - ) - if initial_index is not None: - initial_index = decoding.flat_batch_beam_expand( - initial_index, num_decodes - ) - else: - expanded_inputs = inputs - expanded_cache = cache - - # expanded_decodes: [batch * num_decodes, len] - # expanded_log_prob: [batch * num_decodes] - # expanded_exits: [batch * num_decodes, len] - # expanded_confidences: [batch * num_decodes, len] - expanded_decodes, expanded_log_prob, expanded_exits, expanded_confidences = ( - _temperature_sample_single_trial( - expanded_inputs, - expanded_cache, - tokens_to_logits, - eos_id, - decode_rng, - temperature, - topk, - topp, - initial_index=initial_index, - max_decode_steps=max_decode_steps, - rescale_log_probs=rescale_log_probs, - state_callback_fn=state_callback_fn, - logit_callback_fn=logit_callback_fn, - ) - ) - - batch_size = inputs.shape[0] - # [batch * num_decodes, len] -> [batch, num_decodes, len] - decodes = decoding.unflatten_beam_dim( - expanded_decodes, batch_size, num_decodes - ) - exits = decoding.unflatten_beam_dim(expanded_exits, batch_size, num_decodes) - confidences = decoding.unflatten_beam_dim( - expanded_confidences, batch_size, num_decodes - ) - # [batch * num_decodes] -> [batch, num_decodes] - log_prob = decoding.unflatten_beam_dim( - expanded_log_prob, batch_size, num_decodes - ) - - # Sort `decodes` and `log_prob` by increasing log probabilities of the sampled - # sequence. - # [batch, num_decodes, 1] - idxs = jnp.expand_dims(jnp.argsort(log_prob, axis=-1), axis=-1) - - # returns [batch, num_decodes, len], [batch, num_decodes] in sorted order. - sorted_decodes = jnp.take_along_axis(decodes, idxs, axis=1) - sorted_log_prob = jnp.take_along_axis( - log_prob, jnp.squeeze(idxs, axis=-1), axis=-1 - ) - sorted_exits = jnp.take_along_axis(exits, idxs, axis=1) - sorted_confidences = jnp.take_along_axis(confidences, idxs, axis=1) - - return sorted_decodes, (sorted_log_prob, sorted_exits, sorted_confidences) - - -def _temperature_sample_single_trial( - inputs: jnp.ndarray, - cache: Mapping[str, jnp.ndarray], - tokens_to_logits: Callable[ - [DecodingState], - Tuple[jnp.ndarray, Mapping[str, jnp.ndarray], jnp.ndarray, jnp.ndarray], - ], - eos_id: int, - prng_key: jnp.ndarray, - temperature: Union[float, jnp.ndarray] = 1.0, - topk: int = 20, - topp: Union[float, jnp.ndarray] = 0.0, - initial_index: Optional[jnp.ndarray] = None, - max_decode_steps: Optional[Union[int, jnp.ndarray]] = None, - rescale_log_probs: bool = True, - state_callback_fn: Optional[ - Callable[[SamplingLoopState], SamplingLoopState] - ] = None, - logit_callback_fn: Optional[ - Callable[[jnp.ndarray, SamplingLoopState], jnp.ndarray] - ] = None, -) -> jnp.ndarray: - """A helper function for `temperature_sample`.""" - - # We can check the values of topp and topk only if they are not dynamic. - if not _is_tracer(topp) and topp and topk: - raise ValueError('At most one of `topp` or `topk` may be non-zero.') - - batch_size, max_decode_len = inputs.shape - - if max_decode_steps is not None: - # We can check the max_decode_steps bounds only if it is not dynamic. - if not _is_tracer(max_decode_steps) and max_decode_steps > inputs.shape[1]: - raise ValueError('Cannot decode more steps than the sequence length.') - - # The number of decode steps required to process the prefix is the number - # of non-zero tokens, since inputs[0] == 0 is the BOS token. - # `max_decode_len[j]` is the number of non-padding tokens in the jth element - # of the returned sequences capped at `len(inputs)`, assuming that the - # early stop doesn't occur. This is true with or without - # `max_decode_steps`. - # When the while loop index `i` for the `j`th element `i[j] = - # max_decode_len[j] - 1`, the generated token populate sequences[i[j]+1]]. - # Since sequences[:, 0] is BOS token, the generated token is - # `max_decode_len[j]`th non-padding tokens and hence `j`th element is - # ended. - max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps - max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len) - - # In the case of starting generation from a non-zero index, it is possible for - # one batch element to reach `max_decode_len` number of decoding steps before - # another. In order to let the last element decoder all the way to - # `max_decode_len` number of steps, we add a final garbage token to the end of - # the sequences. Any element that has reached `max_decode_len` before the rest - # of the elements will continually overwrite this token until all elements - # finish. - # [batch, length+1] -> [batch, length+2] - extra_input_tokens = 2 - expanded_prompt_inputs = jnp.append( - inputs, - jnp.zeros((batch_size, extra_input_tokens), dtype=inputs.dtype), - axis=1, - ) - end_marker = jnp.array(eos_id) - - temperature = jnp.asarray(temperature) - - # Initialize sampling loop state. - # initial loop PRNGKey - rng0 = prng_key - # the per batch-item holding current token in loop. - if initial_index is None: - # the per batch-item loop position counter. - i0 = jnp.zeros((batch_size), dtype=jnp.int32) - # the per batch-item holding current token in loop. - token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32) - else: - # the per batch-item loop position counter. - i0 = initial_index - # the per batch-item holding current token in loop. - # Select the token that the initial index is pointing to. - token0 = jnp.take_along_axis( - expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1 - ) - # per batch-item state bit indicating if sentence has finished. - ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_) - # (batch, length+2) array containing prefix prompt tokens for sampling loop - # as well as the generated output of newly sampled tokens. - sequences0 = expanded_prompt_inputs - log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32) - confidences0 = -1.0 * jnp.ones((batch_size, max_decode_len), jnp.float32) - exits0 = -1 * jnp.ones((batch_size, max_decode_len), jnp.int32) - sampling_loop_init_state = SamplingLoopState( - i0, - sequences0, - cache, - token0, - ended0, - rng0, - log_prob0, - confidences0, - exits0, - ) - # Initial eos count to be used to determine whether eos is "generated". Many - # inputs follow the format bos, inputs..., eos, targets..., eos. By counting - # the number of eos tokens we can detect when a new one is added, instead of - # just finding the one that probably ends the inputs. - # [batch, 1] - initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True) - - def sampling_loop_cond_fn(state: SamplingLoopState) -> bool: - """Sampling loop termination condition.""" - # Have all sampled sequences reached an end marker? - # Different elements in the batch can be at different loop indices, if any - # of our examples are not at the end, keep going. - all_sequences_ended = jnp.all(state.ended) - return ~all_sequences_ended # pytype: disable=bad-return-type # jnp-type - - def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState: - """Sampling loop state update.""" - - if state_callback_fn is not None: - state = state_callback_fn(state) - - # Split RNG for sampling. - rng1, rng2 = random.split(state.rng) - # Call fast-decoder model on current tokens to get next-position logits. - decoding_state = DecodingState( - cur_index=state.cur_index, - sequences=state.sequences[:, :-extra_input_tokens], - cur_token=state.cur_token, - cache=state.cache, - ) - confidences = state.confidences - exits = state.exits - logits, new_cache, conf, exit_layer = tokens_to_logits(decoding_state) - # Sample next token from logits. - - if logit_callback_fn is not None: - logits = logit_callback_fn(logits, state) - - def sample_logits_with_nonzero_temperature(logits): - scaled_logits = logits / jnp.maximum(temperature, MIN_TEMPERATURE) - if topk: - scaled_logits = binary_search.topk_mask(scaled_logits, topk, NEG_INF) # pytype: disable=wrong-arg-types # jax-ndarray - - # When topp is dynamic, we always use it since we cannot check - # non-zeroness (but it will have no effect if topp is 0.0). - if _is_tracer(topp) or topp: - scaled_logits = binary_search.topp_mask(scaled_logits, topp, NEG_INF) # pytype: disable=wrong-arg-types # jax-ndarray - - # [batch] - next_token = random.categorical(rng1, scaled_logits).astype(jnp.int32) - - # log probability of the current token conditioned on the previously - # sampled and prefix tokens. - # [batch, vocab] -> [batch, vocab] - if rescale_log_probs: - log_probs = jax.nn.log_softmax(scaled_logits) - else: - log_probs = jax.nn.log_softmax(logits) - # [batch, vocab] -> [batch] - next_log_prob = jnp.squeeze( - jnp.take_along_axis( - log_probs, jnp.expand_dims(next_token, axis=1), axis=-1 - ), - axis=-1, - ) - - return (next_token, next_log_prob) - - def sample_logits_with_zero_temperature(logits): - # For zero temperature, we always want the greedy output, regardless - # of the values of topk and topp. - - next_token = jnp.argmax(logits, -1).astype(jnp.int32) - - if rescale_log_probs: - next_log_prob = jnp.zeros_like(next_token, dtype=jnp.float32) - else: - log_probs = jax.nn.log_softmax(logits) - next_log_prob = jnp.squeeze( - jnp.take_along_axis( - log_probs, jnp.expand_dims(next_token, axis=1), axis=-1 - ), - axis=-1, - ) - - return (next_token, next_log_prob) - - # Perform sampling with temperature - (next_token, next_log_prob) = lax.cond( - temperature > MIN_TEMPERATURE, - sample_logits_with_nonzero_temperature, - sample_logits_with_zero_temperature, - logits, - ) - - # When different batch elements are at different points in the loop counter, - # it is possible that an element that started at a higher index will reach - # `max_decode_len` before other elements. When this happens we need to make - # sure this element continuous overwrites our new garbage collection index. - # Here we clamp `i` to `max_decode_len`. This will cause the a write to - # `max_decode_len + 1` which is the final index in `sequences`. Subsequent - # loop body executions will also get their value clamped causing continual - # overwriting of the final garbage position until all examples are finished. - i = jnp.minimum(state.cur_index, max_decode_len) - - # Only use sampled tokens if we're past provided prefix tokens. - # Select the next token from sequences. - # [batch] - next_input_token = jnp.squeeze( - jnp.take_along_axis( - state.sequences, jnp.expand_dims(i + 1, axis=1), axis=1 - ), - axis=1, - ) - # Check if the next token is padding (a target) or non-padding (an input). - # Mask will have `1` for targets and `0` for inputs. - out_of_prompt = next_input_token == 0 - # Select the sampled next token for targets and the actual next token for - # inputs (teacher forcing). - # [batch] - next_token = next_token * out_of_prompt + next_input_token * ~out_of_prompt - - # only add probability if outside prefix region - # [batch] -> [batch] - next_log_prob = state.log_prob + ( - next_log_prob * out_of_prompt - ) * jnp.squeeze(~state.ended, axis=-1).astype(jnp.int32) - - # [batch] -> [batch, 1] - next_token = jnp.expand_dims(next_token, axis=-1) - - # If end-marker reached for batch item, only emit padding tokens. - # [batch, 1] * [batch, 1] -> [batch, 1] - next_token_or_endpad = next_token * ~state.ended - # Add current sampled tokens to recorded sequences. - one_hot = jax.nn.one_hot( - i + 1, state.sequences.shape[1], dtype=state.sequences.dtype - ) - new_sequences = ( - state.sequences * (1 - one_hot) + next_token_or_endpad * one_hot - ) - # new_sequences = dynamic_update_vector_slice_in_dim(sequences, - # next_token_or_endpad, - # i + 1, - # 0) - # Count eos tokens in the sequences and compare to the initial count - # [batch, 1] - cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True) - # [batch, 1] - - # Have we reached max decoding length? - # We generally index into sequences[:, i + 1], and sequences.shape[1] = - # max_decode_len + 2, therefore i == max_decode_len - 1 will write to - # sequences[-2] which is our last valid location. i == max_decode_len will - # write to sequences[-1] which is our garbage collection token. Thus `i` - # should be strictly less than max_decode_len. - has_additional_eos = cur_eos_count > initial_eos_count - ended = ( - state.ended - | has_additional_eos - | jnp.expand_dims(i >= max_decode_len - 1, axis=1) - ) - - new_conf = confidences.at[:, i].set(conf) - new_exits = exits.at[:, i].set(exit_layer) - - return SamplingLoopState( - i + 1, - new_sequences, - new_cache, - next_token_or_endpad, - ended, - rng2, - next_log_prob, - new_conf, - new_exits, - ) - - # Run sampling loop and collect final state. - final_state = lax.while_loop( - sampling_loop_cond_fn, sampling_loop_body_fn, sampling_loop_init_state - ) - - # Pick part of the state corresponding to the sampled sequences. - final_sequences = final_state.sequences - log_prob = final_state.log_prob - final_exits = final_state.exits - final_confidences = final_state.confidences - # Drop the first position because they are dummy bos tokens. Drop the new - # garbage collection token at the end too. - return final_sequences[:, 1:-1], log_prob, final_exits, final_confidences # pytype: disable=bad-return-type # jax-ndarray diff --git a/t5x-main/t5x/contrib/calm/models.py b/t5x-main/t5x/contrib/calm/models.py deleted file mode 100644 index a107dde3c22f3b1b70d83c1a7bb3fc37abf709a4..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/calm/models.py +++ /dev/null @@ -1,1035 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Models with CALM early exit functionality.""" - -import copy -import functools -from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union - -import clu.metrics as clu_metrics -import flax -from flax import linen as nn -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np -import seqio -from t5x import decoding -from t5x import losses -from t5x import metrics as metrics_lib -from t5x import models -from t5x import optimizers -from t5x.contrib.calm import decoding as calm_decoding -import tensorflow as tf -import typing_extensions - - -# Remove _ShardedDeviceArray when users of t5x have their types updated -_ShardedDeviceArray = Any -Array = Union[np.ndarray, jnp.ndarray, _ShardedDeviceArray, tf.Tensor] -MetricsMap = metrics_lib.MetricsMap -PyTree = Any - - - - -class TokensIdsToLogitsCallable(typing_extensions.Protocol): - """Token ids to logits mapping call signature.""" - - def __call__( - self, decoding_state: calm_decoding.DecodingState - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray], jnp.ndarray, jnp.ndarray]: - """Performs forward pass to convert token ids to logits. - - Args: - decoding_state: Current decoding state, including current token ids and - cache. - - Returns: - logits: logits with a shape [batch_size, vocab_size]. - cache: An updated cache. - confidences: Float array of shape [batch_size, max_decode_len] with the - confidence values measured at the exit layer. - layres: Int array of shape [batch_size, max_decode_len] with the exited - layer per token. - """ - ... - - -class DecodeFnCallable(typing_extensions.Protocol): - """Decoding function call signature.""" - - def __call__( - self, - *, - inputs: jnp.ndarray, - cache: Mapping[str, jnp.ndarray], - tokens_to_logits: TokensIdsToLogitsCallable, - eos_id: int, - num_decodes: int, - decode_rng: Optional[jax.Array], - cache_offset: int, - **kwargs, - ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]: - """Decoding function interface. - - Args: - inputs: [batch_size, max_decode_len] int32 sequence of tokens, with non-0 - prefix tokens to be used as a forced prompt. - cache: flax attention cache. - tokens_to_logits: fast autoregressive decoder function taking single token - slices and cache and returning next-token logits and updated cache. - eos_id: end-of-sentence token for target vocabulary. - num_decodes: number of decoded sequences to be returned. - decode_rng: an optional JAX PRNG Key for stochastic sampling routines. - cache_offset: axis offset for cache, arising from scanned layers. - **kwargs: an optional kwargs. One common usecase of this is passing - decoding parameters at the callsite. - - Returns: - decodes: Array of sequences: [batch_size, num_decodes, max_decode_len]. - The `num_decodes` dimension is expected to be sorted by the `scores`, - i.e., `decodes[:, -1, :] has the highest scores among `num_decodes` - decoded sequences. - scores: Array of log likelihood scores: [batch_size, num_decodes] - confidences: Float array of shape [batch_size, max_decode_len] with the - confidence values measured at the exit layer. - layres: Int array of shape [batch_size, max_decode_len] with the exited - layer per token. - """ - ... - - -class EncoderDecoderModel(models.EncoderDecoderModel): - """Wrapper class for the models.Transformer nn.module. - - Incorporates CALM early exit functionalities. - """ - - def __init__( - self, - module: nn.Module, - input_vocabulary: seqio.Vocabulary, - output_vocabulary: seqio.Vocabulary, - optimizer_def: optimizers.OptimizerDefType, - decode_fn: DecodeFnCallable = calm_decoding.temperature_sample, - feature_converter_cls: Optional[ - Callable[..., seqio.FeatureConverter] - ] = None, - label_smoothing: float = 0.0, - z_loss: float = 0.0, - loss_normalizing_factor: Optional[float] = None, - apply_early_inference: bool = False, - decoder_layers: int = 12, - conf_threshold: float = 1.0, - min_exit: int = 0, - first_exit: int = 0, - exit_interval: int = 1, - aggregation_weights: int = 1, - oracle_tok_consistency: bool = False, - oracle_cache: bool = False, - oracle_tok_noisy_cache: bool = False, - conf_method: str = 'softmax_max', - train_meta_cls: bool = False, - geomlike_loss: bool = False, - position_adjusted_threshold: bool = False, - position_temp: int = 4, - ): - self.apply_early_inference = apply_early_inference - self.decoder_layers = decoder_layers - self.conf_threshold = conf_threshold - self.min_exit = min_exit - self.first_exit = first_exit - self.exit_interval = exit_interval - self.aggregation_weights = aggregation_weights - self.oracle_tok_consistency = oracle_tok_consistency - self.oracle_cache = oracle_cache - self.oracle_tok_noisy_cache = oracle_tok_noisy_cache - self.conf_method = conf_method - self.train_meta_cls = train_meta_cls - self.geomlike_loss = geomlike_loss - self.position_adjusted_threshold = position_adjusted_threshold - self.position_temp = position_temp - super().__init__( - module=module, - input_vocabulary=input_vocabulary, - output_vocabulary=output_vocabulary, - optimizer_def=optimizer_def, - decode_fn=decode_fn, - feature_converter_cls=feature_converter_cls, - label_smoothing=label_smoothing, - z_loss=z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - - def get_pred_confidence( - self, # pytype: disable=annotation-type-mismatch # jax-ndarray - logits: jnp.ndarray = None, - prev_state: jnp.ndarray = None, - new_state: jnp.ndarray = None, - meta_score: jnp.ndarray = None, - ) -> jnp.ndarray: - """Computes the of decoder in its current prediction. - - The confidence function is determined by self.conf_method. - - Args: - logits: Array with last dimension holding the logits over the output - vocabulary. - prev_state: Hidden state from previous layer. - new_state: Hidden state of current layer. - meta_score: The confidence score of an early-exit classifier. - - Returns: - confidence: Per example confidence scores. - """ - if self.conf_method == 'softmax_diff': - # Computes confidence by taking the difference between the top two softmax - # scores. This implementation can be slow due to sorting all logits. - assert logits is not None - logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending. - sorted_probs = nn.softmax(logits_sorted, axis=-1) - return sorted_probs[..., 0] - sorted_probs[..., 1] - - if self.conf_method == 'softmax_diff_approx': - # A faster softmax approximate difference implementation. - assert logits is not None - probs = nn.softmax(logits, axis=-1) - top_2 = jax.lax.approx_max_k(probs, k=2)[0] - return top_2[..., 0] - top_2[..., 1] - - if self.conf_method == 'softmax_max': - # Computes confidence by taking the maximum softmax value. - assert logits is not None - return nn.softmax(logits, axis=-1).max(axis=-1) - - elif self.conf_method == 'state': - # Computes the confidence by the cosine similarity between the current - # hidden state and the previous one. - assert prev_state is not None and new_state is not None - conf = jnp.inner(prev_state, new_state) / ( - jnp.linalg.norm(prev_state) * jnp.linalg.norm(new_state) - ) - return conf.squeeze() - - elif self.conf_method == 'meta': - # Using scores from early-exit classifier (already given in input). - assert meta_score is not None - return meta_score - - else: - raise NotImplementedError( - f'Confidence method {self.conf_method} is not implemented.' - ) - - def loss_fn_meta_cls( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: jnp.ndarray, - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Loss function for the meta early exit classifier (meta_cls). - - Should also be used with 'return_all_logits' option for the decoder. Can be - used for a second training step where we freeze the rest of the model, - (non-meta_cls parameters). - - Args: - params: model parameters. - batch: a batch of inputs. - dropout_rng: rng to use for dropout, or None for deterministic mode. - - Returns: - loss: the loss computed for the given inputs and parameters. - metrics: a mapping of metrics computed for this batch. - """ - - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] - (loss_normalizing_factor, weights) = ( - losses.get_loss_normalizing_factor_and_weights( - self._loss_normalizing_factor, batch - ) - ) - - all_logits = self._compute_logits(params, batch, dropout_rng) - assert isinstance( - all_logits, tuple - ), 'Verify that meta_cls was initialized in decoder.' - all_meta_logits = all_logits[1] - all_logits = all_logits[0] - - # Create meta labels based on consistency of intermediate prediction with - # the top prediction. - predictions = all_logits.argmax(-1) - top_pred = predictions[-1] - all_meta_labels = jnp.array(top_pred == predictions, dtype=jnp.int32) - - # Aggregate meta loss across layers. - ( - all_loss, - all_total_z_loss, - ) = ( - [], - [], - ) - for meta_logits, meta_labels in zip( - all_meta_logits[:-1], all_meta_labels[:-1] - ): - # Balance across the positive/ negative labels. - balanced_weights = weights.copy().astype(float) # pytype: disable=attribute-error # jnp-type - - pos_num = (meta_labels * weights == 1).sum() - neg_num = ((1 - meta_labels) * weights == 1).sum() - - pos_weight = 1 - (pos_num / (pos_num + neg_num)) - neg_weight = 1 - (neg_num / (pos_num + neg_num)) - balanced_weights = ( - weights * meta_labels * pos_weight - + weights * (1 - meta_labels) * neg_weight - ) - - # Compute layer loss. - all_loss_i, all_total_z_loss_i, _ = losses.compute_weighted_cross_entropy( - meta_logits, - targets=meta_labels, - label_smoothing=self._label_smoothing, - z_loss=self._z_loss, - weights=balanced_weights, - loss_normalizing_factor=loss_normalizing_factor, - ) - all_loss.append(all_loss_i) - all_total_z_loss.append(all_total_z_loss_i) - - loss = jnp.average(jnp.array(all_loss), 0) - total_z_loss = jnp.average(jnp.array(all_total_z_loss), 0) - - metrics = self._compute_metrics( - logits=all_logits[-1], - targets=batch['decoder_target_tokens'], - mask=weights, - loss=loss, - z_loss=total_z_loss, - ) - - # Meta metrics. - for i, (meta_logits, meta_labels) in enumerate( - zip(all_meta_logits[:-1], all_meta_labels[:-1]) - ): - meta_metrics = { - f'meta_accuracy/layer_{i}': clu_metrics.Accuracy.from_model_output( - logits=meta_logits, - labels=meta_labels.astype(jnp.int32), - mask=weights, - ), - f'meta_loss/layer_{i}': metrics_lib.AveragePerStep(total=all_loss[i]), - } - metrics.update(meta_metrics) - - return loss, metrics - - def loss_fn_meta_cls_geom_like( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: jnp.ndarray, - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Geometric-like loss function for the meta early exit classifier. - - This version follows the geometric-like (consistency-based) of - https://arxiv.org/pdf/1910.10073.pdf - that optimizes the meta cls as a sequence of decisions instead of - independent predcitions over layers. - - Should also be used with 'return_all_logits' option for the decoder. Can be - used for a second training step where we freeze the rest of the model, - (non-meta_cls parameters). - - - Args: - params: model parameters. - batch: a batch of inputs. - dropout_rng: rng to use for dropout, or None for deterministic mode. - - Returns: - loss: the loss computed for the given inputs and parameters. - metrics: a mapping of metrics computed for this batch. - """ - - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] - (loss_normalizing_factor, weights) = ( - losses.get_loss_normalizing_factor_and_weights( - self._loss_normalizing_factor, batch - ) - ) - - all_logits = self._compute_logits(params, batch, dropout_rng) - assert isinstance( - all_logits, tuple - ), 'Verify that meta_cls was initialized in decoder.' - all_meta_logits = all_logits[1] - all_logits = all_logits[0] - - # Create meta labels based on consistency of intermediate prediction with - # the top prediction. - predictions = all_logits.argmax(-1) - top_pred = predictions[-1] - all_meta_labels = jnp.array(top_pred == predictions, dtype=jnp.int32) - - # Here, this is treated as an L-way classification task (L=decoder layers). - all_meta_labels_multiclass = all_meta_labels.argmax(0) - - # Geometric-like aggregation. - all_meta_scores = nn.log_softmax(all_meta_logits, axis=-1) - all_meta_scores_pos = all_meta_scores[..., 1] - all_meta_scores_neg = all_meta_scores[..., 0] - non_stop_probs = all_meta_scores_neg.cumsum(0) - all_meta_scores_neg - geom_like_probs = non_stop_probs + all_meta_scores_pos - geom_like_probs = jnp.moveaxis(geom_like_probs, 0, -1) - - loss, total_z_loss, _ = losses.compute_weighted_cross_entropy( - geom_like_probs, - targets=all_meta_labels_multiclass, - label_smoothing=self._label_smoothing, - z_loss=self._z_loss, - weights=weights, - loss_normalizing_factor=loss_normalizing_factor, - ) - - total_z_loss = 0.0 # hardcoded - - metrics = self._compute_metrics( # pytype: disable=wrong-arg-types # jax-ndarray - logits=all_logits[-1], - targets=batch['decoder_target_tokens'], - mask=weights, - loss=loss, - z_loss=total_z_loss, - ) - - # Meta metrics. - for i, (meta_logits, meta_labels) in enumerate( - zip(all_meta_logits[:-1], all_meta_labels[:-1]) - ): - meta_metrics = { - f'meta_accuracy/layer_{i}': clu_metrics.Accuracy.from_model_output( - logits=meta_logits, - labels=meta_labels.astype(jnp.int32), - mask=weights, - ), - } - meta_metrics.update({ - 'meta_accuracy/multiclass': clu_metrics.Accuracy.from_model_output( - logits=geom_like_probs, - labels=all_meta_labels_multiclass.astype(jnp.int32), - mask=weights, - ), - }) - metrics.update(meta_metrics) - - return loss, metrics - - def loss_fn( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: jnp.ndarray, - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Loss function for anytime predictions across model layers. - - Should be used with `return_all_logits` option for the decoder. Per-layer - loss is aggregated with a weighted sum, according `aggregation_weights`. - - if `train_meta_cls` is True, will call `loss_fn_meta_cls` instead. - - Args: - params: model parameters. - batch: a batch of inputs. - dropout_rng: rng to use for dropout, or None for deterministic mode. - - Returns: - loss: the loss computed for the given inputs and parameters. - metrics: a mapping of metrics computed for this batch. - """ - - if self.train_meta_cls: - if self.geomlike_loss: - return self.loss_fn_meta_cls_geom_like(params, batch, dropout_rng) - else: - return self.loss_fn_meta_cls(params, batch, dropout_rng) - - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] - (loss_normalizing_factor, weights) = ( - losses.get_loss_normalizing_factor_and_weights( - self._loss_normalizing_factor, batch - ) - ) - - all_logits = self._compute_logits(params, batch, dropout_rng) - all_loss, all_total_z_loss = [], [] - for logits in all_logits: - all_loss_i, all_total_z_loss_i, _ = losses.compute_weighted_cross_entropy( - logits, - targets=batch['decoder_target_tokens'], - weights=weights, - label_smoothing=self._label_smoothing, - z_loss=self._z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - all_loss.append(all_loss_i) - all_total_z_loss.append(all_total_z_loss_i) - - if self.aggregation_weights == -1: - # Geometric series with a=1, r=2. - avg_weights = jnp.geomspace(1, 2 ** (len(all_loss) - 1), len(all_loss)) - elif self.aggregation_weights == 0: - avg_weights = jnp.ones(len(all_loss)) - else: - avg_weights = jnp.arange( - 1, - self.aggregation_weights * len(all_loss) + 1, - step=self.aggregation_weights, - ) - loss = jnp.average(jnp.array(all_loss), 0, avg_weights) - total_z_loss = jnp.average(jnp.array(all_total_z_loss), 0, avg_weights) - - # Based on last logits. - metrics = self._compute_metrics( - logits=all_logits[-1], - targets=batch['decoder_target_tokens'], - mask=weights, - loss=loss, - z_loss=total_z_loss, - ) - - # Per layer metrics. - for i, logits in enumerate(all_logits[:-1]): - meta_metrics = { - f'accuracy_per_layer/layer_{i}': ( - clu_metrics.Accuracy.from_model_output( - logits=logits, - labels=batch['decoder_target_tokens'], - mask=weights, - ) - ), - f'loss_per_layer/layer_{i}': metrics_lib.AveragePerStep( - total=all_loss[i] - ), - } - metrics.update(meta_metrics) - - return loss, metrics - - def _compute_logits_from_slice_early_exit( - self, - decoding_state: calm_decoding.DecodingState, - params: PyTree, - encoded_inputs: jnp.ndarray, - raw_inputs: jnp.ndarray, - max_decode_length: int, - conf_threshold: float, - pos: int = 0, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray], jnp.ndarray, jnp.ndarray]: - """Token slice to logits from decoder model with early exit mechanism.""" - - num_layers = self.decoder_layers - flat_ids: jnp.ndarray = decoding_state.cur_token - flat_cache: Mapping[str, jnp.ndarray] = decoding_state.cache - - if self.oracle_tok_consistency or self.oracle_tok_noisy_cache: - # Compute the prediction of the full model for reference for the oracle. - oracle_flat_cache = copy.deepcopy(flat_cache) - oracle_flat_logits, _ = self.module.apply( - {'params': params, 'cache': oracle_flat_cache}, - encoded_inputs, - raw_inputs, # only needed for encoder padding mask - flat_ids, - flat_ids, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - start_idx=0, - end_idx=None, - return_prelogits=False, - mutable=['cache'], - method=self.module.decode, - ) - oracle_tok_pred = oracle_flat_logits.argmax() - - # Get the computation intervals (per layers) of the decoder between exits. - keep_inds = list(range(self.first_exit + 1, num_layers, self.exit_interval)) - comp_intervals = [(0, self.first_exit + 1)] + [ - (i, j) for i, j in zip(keep_inds, keep_inds[1:] + [num_layers]) - ] - - # First run the decoder but only up to the first exit. - decoder_hidden, new_vars = self.module.apply( - {'params': params, 'cache': flat_cache}, - encoded_inputs, - raw_inputs, # only needed for encoder padding mask - flat_ids, - flat_ids, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - start_idx=0, - end_idx=comp_intervals[0][1], - return_prelogits=True, - mutable=['cache'], - method=self.module.decode, - ) - - # If using meta_cls. - if isinstance(decoder_hidden, tuple): - meta_score = nn.softmax(decoder_hidden[1], axis=-1)[..., 1] - decoder_hidden = decoder_hidden[0] - else: - meta_score = None - - new_flat_cache = new_vars['cache'] - if 'softmax' in self.conf_method: - flat_logits = self.module.apply( - {'params': params, 'cache': flat_cache}, - decoder_hidden, - logit_mask=None, - enable_dropout=False, - method=self.module.compute_logits, - ) - else: - flat_logits = None - - if self.conf_method == 'state': - # Always skip the first 'exit' since previous state is missing. - conf = 0 - else: - conf = self.get_pred_confidence(logits=flat_logits, meta_score=meta_score) - - # Used to enable a positional argument (decoder_embedded_input) for switch. - def prt( - a, - b, - c, - d, - e, - f, - start_idx=0, - end_idx=None, - only_propagate_state=False, - **kwargs, - ): # pylint: disable=unused-argument - return self.module.apply( - a, - b, - c, - d, - e, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - start_idx=start_idx, - end_idx=end_idx, - decoder_embedded_input=f, - return_prelogits=True, - only_propagate_state=only_propagate_state, - mutable=['cache'], - method=self.module.decode, - ) - - # Segments of the model between exits, passed to lax.switch. - branches = [ - functools.partial( # pylint: disable=g-complex-comprehension - prt, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - start_idx=interval[0], - end_idx=interval[1], - mutable=['cache'], - method=self.module.decode, - ) - for interval in comp_intervals - ] - - # Switch branches for state propagation. Last branch has zero layers, to be - # used if no propagation is needed (i.e., didn't exit early). - state_prop_branches = [ - functools.partial( # pylint: disable=g-complex-comprehension - prt, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - start_idx=interval[0], - end_idx=None, - mutable=['cache'], - only_propagate_state=True, - method=self.module.decode, - ) - for interval in comp_intervals - ] + [ - functools.partial( # pylint: disable=g-complex-comprehension - prt, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - start_idx=comp_intervals[-1][1], - end_idx=comp_intervals[-1][1], # same idx (to just skip) - mutable=['cache'], - only_propagate_state=True, - method=self.module.decode, - ) - ] - - # TODO(talschuster) convert to named tuple. - init_state = ( - flat_logits, - decoder_hidden, - new_flat_cache, - conf, - 1, - comp_intervals[0][1], - meta_score, - ) - - # Runs a segment of the model. - def body_fn(state): - _, decoder_hidden, new_flat_cache, _, interval, layer, _ = state - - new_decoder_hidden, new_vars = lax.switch( - interval, - branches, - {'params': params, 'cache': new_flat_cache}, - encoded_inputs, - raw_inputs, # only needed for encoder padding mask - flat_ids, - flat_ids, - decoder_hidden, - ) - - # If using meta_cls. - if isinstance(new_decoder_hidden, tuple): - meta_score = nn.softmax(new_decoder_hidden[1], axis=-1)[..., 1] - new_decoder_hidden = new_decoder_hidden[0] - else: - meta_score = None - - if 'softmax' in self.conf_method: - cur_flat_logits = self.module.apply( - {'params': params, 'cache': new_flat_cache}, - new_decoder_hidden, - logit_mask=None, - enable_dropout=False, - method=self.module.compute_logits, - ) - new_flat_logits = cur_flat_logits - else: - new_flat_logits = None - - new_flat_cache = new_vars['cache'] - - new_conf = self.get_pred_confidence( - logits=new_flat_logits, - prev_state=decoder_hidden, - new_state=new_decoder_hidden, - meta_score=meta_score, - ) - - layer = lax.min(layer + self.exit_interval, num_layers) - return ( - new_flat_logits, - new_decoder_hidden, - new_flat_cache, - new_conf, - interval + 1, - layer, - meta_score, - ) - - # Stopping condition (loop continues until it's False). - def cond_fn(state): - if self.position_adjusted_threshold: - # Decays the confidence threshold with decoding time step. - correct_by_pos = ( - lambda i: conf_threshold - * jnp.exp( # pylint: disable=g-long-lambda - -self.position_temp * i / max_decode_length - ) - / 10 - + 9 * conf_threshold / 10 - ) - adjusted_threshold = correct_by_pos(jnp.min(pos)) - else: - adjusted_threshold = conf_threshold - - if self.oracle_tok_consistency: - # Oracle to exit first time predictiong is the same as top layer. - flat_logits, _, _, _, _, layer, _ = state - return (flat_logits.argmax() != oracle_tok_pred) & (layer < num_layers) - else: - # Continues until average batch confidence reaches the threshold, or - # until all layers were exhausted. Also, doesn't exit before min_exit. - _, _, _, conf, _, layer, _ = state - return ((jnp.min(conf) < adjusted_threshold) & (layer < num_layers)) | ( - layer < self.min_exit - ) - - ( - flat_logits, - new_decoder_hidden, - new_flat_cache, - conf, - interval, - layer, - new_meta_score, - ) = lax.while_loop(cond_fn, body_fn, init_state) - - if 'softmax' not in self.conf_method: - # Computes the softmax over the output vocabulary only after exiting. - flat_logits = self.module.apply( - {'params': params, 'cache': new_flat_cache}, - new_decoder_hidden, - logit_mask=None, - enable_dropout=False, - method=self.module.compute_logits, - ) - - if self.oracle_cache: - # Run the rest of the layers to compute the real cache (oracle setting). - def cond_fn_complete_run(state): - _, _, _, _, _, layer, _ = state - return layer < num_layers - 1 - - post_exit_state = ( - flat_logits, - new_decoder_hidden, - new_flat_cache, - conf, - interval, - layer - self.exit_interval, - new_meta_score, - ) - _, _, new_flat_cache, _, _, _, _ = lax.while_loop( - cond_fn_complete_run, body_fn, post_exit_state - ) - else: - # If some decoding layers were skipped, we want to pass the state - # from the last computed layer to all the upstream skipped layers. This - # way, the next tokens, if continued to higher layers, can attend back to - # previous token states. Here, the hidden-state is passed to the higher - # layers and let each layer compute its own key-value projections. - _, new_vars = lax.switch( - interval, - state_prop_branches, - {'params': params, 'cache': new_flat_cache}, - encoded_inputs, - raw_inputs, # only needed for encoder padding mask - flat_ids, - flat_ids, - decoder_hidden, - ) - new_flat_cache = new_vars['cache'] - - if self.oracle_tok_noisy_cache: - # Takes the logits of the top layer, but uses the hidden state from the - # exited layer. - oracle_flat_logits = jnp.squeeze(oracle_flat_logits, axis=1) - return oracle_flat_logits, new_flat_cache, conf, layer - - # Remove sequence length dimension since it's always 1 during decoding. - flat_logits = jnp.squeeze(flat_logits, axis=1) - - return flat_logits, new_flat_cache, conf, layer - - def predict_batch_with_aux( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.Array] = None, - decoder_params: Optional[MutableMapping[str, Any]] = None, - return_all_decodes: bool = False, - num_decodes: int = 1, - prompt_with_targets: bool = False, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Predict with fast decoding beam search on a batch. - - Here we refer to "parameters" for values that can be compiled into the - model dynamically, as opposed to static configuration settings that require - a recompile. For example, the model weights and the decoder brevity-penalty - are parameters and can be modified without requiring a recompile. The number - of layers, the batch size and the decoder beam size are configuration - options that require recompilation if changed. - - This method can be used with a customizable decoding function as long as it - follows the signature of `DecodeFnCallable`. In order to provide a unified - interface for the decoding functions, we use a generic names. For example, a - beam size is a concept unique to beam search. Conceptually, it corresponds - to the number of sequences returned by the beam search. Therefore, the - generic argument `num_decodes` corresponds to the beam size if - `self._decode_fn` is a beam search. For temperature sampling, `num_decodes` - corresponds to the number of independent sequences to be sampled. Typically - `num_decodes = 1` is used for temperature sampling. - - If `return_all_decodes = True`, the return tuple contains the predictions - with a shape [batch, num_decodes, max_decode_len] and the scores (i.e., log - probability of the generated sequence) with a shape [batch, num_decodes]. - - If `return_all_decodes = False`, the return tuple contains the predictions - with a shape [batch, max_decode_len] and the scores with a shape [batch]. - - `decoder_params` can be used to pass dynamic configurations to - `self.decode_fn`. An example usage is to pass different random seed (i.e., - `jax.random.PRNGKey(seed)` with different `seed` value). This can be done by - setting `decoder_params['decode_rng'] = jax.random.PRNGKey(seed)`. - - If `prompt_with_targets = True`, then `decoder_prompt_inputs` is initialized - from the batch's `decoder_input_tokens`. The EOS is stripped to avoid - decoding to stop after the prompt by matching to `output_vocabulary.eos_id`. - - Args: - params: model parameters. - batch: a batch of inputs. - rng: an optional RNG key to use during prediction, which is passed as - 'decode_rng' to the decoding function. - decoder_params: additional (model-independent) parameters for the decoder. - return_all_decodes: whether to return the entire beam or just the top-1. - num_decodes: the number of beams to use in beam search. - prompt_with_targets: Whether the force decode decoder_inputs. - - Returns: - A tuple containing: - the batch of predictions, with the entire beam if requested - an auxiliary dictionary of decoder scores - """ - # Prepare zeroed-out autoregressive cache. - # [batch, input_len] - inputs = batch['encoder_input_tokens'] - # [batch, target_len] - target_shape = batch['decoder_input_tokens'].shape - target_type = batch['decoder_input_tokens'].dtype - _, variables_with_cache = self.module.apply( - {'params': params}, - jnp.ones(inputs.shape, inputs.dtype), - jnp.ones(target_shape, target_type), - jnp.ones(target_shape, target_type), - decode=True, - enable_dropout=False, - mutable=['cache'], - ) - - cache = variables_with_cache['cache'] - - # Prepare transformer fast-decoder call for beam search: for beam search, we - # need to set up our decoder model to handle a batch size equal to - # batch_size * num_decodes, where each batch item's data is expanded - # in-place rather than tiled. - # i.e. if we denote each batch element subtensor as el[n]: - # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] - # [batch * num_decodes, input_len, emb_dim] - encoded_inputs = decoding.flat_batch_beam_expand( - self.module.apply( - {'params': params}, - inputs, - enable_dropout=False, - method=self.module.encode, - ), - num_decodes, - ) - - # [batch * num_decodes, input_len] - raw_inputs = decoding.flat_batch_beam_expand(inputs, num_decodes) - - if self.apply_early_inference: - tokens_ids_to_logits = functools.partial( - self._compute_logits_from_slice_early_exit, - params=params, - encoded_inputs=encoded_inputs, - raw_inputs=raw_inputs, - max_decode_length=target_shape[1], - conf_threshold=self.conf_threshold, - ) - else: - tokens_ids_to_logits = functools.partial( - self._compute_logits_from_slice, - params=params, - encoded_inputs=encoded_inputs, - raw_inputs=raw_inputs, - max_decode_length=target_shape[1], - ) - - if decoder_params is None: - decoder_params = {} - if rng is not None: - if decoder_params.get('decode_rng') is not None: - raise ValueError( - f'Got RNG both from the `rng` argument ({rng}) and' - " `decoder_params['decode_rng']`" - f' ({decoder_params["decode_rng"]}). Please specify one or the' - ' other.' - ) - decoder_params['decode_rng'] = rng - - # `decoder_prompt_inputs` is initialized from the batch's - # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop - # after the prompt by matching to `output_vocabulary.eos_id`. - # These inputs are ignored by the beam search decode fn. - if prompt_with_targets: - decoder_prompt_inputs = batch['decoder_input_tokens'] - decoder_prompt_inputs = decoder_prompt_inputs * ( - decoder_prompt_inputs != self.output_vocabulary.eos_id - ) - else: - decoder_prompt_inputs = jnp.zeros_like(batch['decoder_input_tokens']) - - # Using the above-defined single-step decoder function, run a - # beam search over possible sequences given input encoding. - # decodes: [batch, num_decodes, max_decode_len + 1] - # scores: [batch, num_decodes] - scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers - - if 'eos_id' not in decoder_params: - decoder_params['eos_id'] = self.output_vocabulary.eos_id - decodes, scores = self._decode_fn( - inputs=decoder_prompt_inputs, - cache=cache, - tokens_to_logits=tokens_ids_to_logits, - num_decodes=num_decodes, - cache_offset=1 if scanned else 0, - **decoder_params, - ) - - # TODO(talschuster) make the decode func return a general dict. - if self.apply_early_inference: - scores, exits, confidences = scores - else: - exits, confidences = [], [] - - # Beam search returns [n_batch, n_beam, n_length] with beam dimension sorted - # in increasing order of log-probability. - # Return the highest scoring beam sequence. - if return_all_decodes: - return decodes, { # pytype: disable=bad-return-type # jax-ndarray - 'scores': scores, - 'exits': exits, - 'confidences': confidences, - } - else: - return decodes[:, -1, :], { - 'scores': scores[:, -1], - 'exits': exits[:, -1, :], - 'confidences': confidences[:, -1, :], - } diff --git a/t5x-main/t5x/contrib/gpu/AUTHORS b/t5x-main/t5x/contrib/gpu/AUTHORS deleted file mode 100644 index a36017acded892d43f2dc1ff2efeb1268808473d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/AUTHORS +++ /dev/null @@ -1,8 +0,0 @@ -# This is the list of T5x's significant contributors. -# -# This does not necessarily list everyone who has contributed code, -# especially since many employees of one corporation may be contributing. -# To see the full list of contributors, see the revision history in -# source control. -Google LLC -NVIDIA Corporation diff --git a/t5x-main/t5x/contrib/gpu/Dockerfile b/t5x-main/t5x/contrib/gpu/Dockerfile deleted file mode 100644 index 4ab560e012bf28bfdcce204f989e4be4f8ef9c2f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/Dockerfile +++ /dev/null @@ -1,16 +0,0 @@ -ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:22.08-tf2-py3 -FROM ${FROM_IMAGE_NAME} - -# Install the latest jax -RUN pip install jax[cuda]==0.4.1 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -# setup directory paths for T5x -ENV TFDS_DATA_DIR=/t5x_home/datasets/ -ENV T5X_DIR=/t5x_home/ -ENV T5X_WORKSPACE_DIR=/t5x_home/workspace -ENV PYTHONPATH=/t5x_home/ -WORKDIR /t5x_home - -# install the requirements for T5x -COPY . . -RUN pip install -e '.[gpu]' diff --git a/t5x-main/t5x/contrib/gpu/NOTICE b/t5x-main/t5x/contrib/gpu/NOTICE deleted file mode 100644 index b425b0b33f028fef0d4be7c56e3614de7ff3cf21..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/NOTICE +++ /dev/null @@ -1,5 +0,0 @@ -T5X in JAX - -This repository includes source code (in "t5x/tfds_pile.py") from: -* https://github.com/EleutherAI/the-pile - diff --git a/t5x-main/t5x/contrib/gpu/README.md b/t5x-main/t5x/contrib/gpu/README.md deleted file mode 100644 index 6e7cc57d2b0a12e6dd331a94bd1c2b7bb543906b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/README.md +++ /dev/null @@ -1,90 +0,0 @@ -# GPU Scripts - -# Warning! -An updated version of T5x with optimized GPU performance and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x). ------ -**NVIDIA no longer recommends using this repository and won't be updating it further.** ------ - -The [t5x/contrib/gpu/scripts_gpu](scripts_gpu) directory contains scripts optimized for GPU usage. - -To get all dependencies for the Pile dataset, install with the `gpu` extra: -```bash -pip install '.[gpu]' -``` - -## Building the container -The Dockerfile in `t5x/contrib/gpu` given will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` - -## Running interactively -Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example: - -`t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir` - -## Downloading The Pile -Run `download_the_pile.py` to download the pile. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. - -## Single Node runs -Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build a T5X model with the Adam optimizer and relevant parameters. These will allow multi-gpu on one host. - -## Multi Node runs -For a SLURM+pyxis cluster, `example*.sub` files provide example slurm submit files (edit with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (dropping some throughput) - -## Convergence -For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes. - -| size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | -| ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- | -| small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | -| large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) | -| xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) | -| xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)| - -Note: Convergence (as shown in log) was not necessarily done with the hardware topology listed, but the listed topology is tested. Estimated Walltime is calculated assuming full throughput (seq/sec) continuously. In practice, there are compilation overheads at the beginning of each run/restart(in cluster settings) + checkpointing overheads (if any). - -(More perf improvements coming soon!) - -Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory. - -## Pretraining run commands - -### Singlenode -small: - -`t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}` - -Finetuning: -MNLI v2: -`t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}` - - -### Multinode -Arguments are as such: - -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` - -small: - -`sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1` - -large: - -`sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1` - -xl: - -`sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1` - -Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from. - -MNLI v2: - -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` - -SQuAD v1.1 - -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` - -On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision. - -WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. diff --git a/t5x-main/t5x/contrib/gpu/__init__.py b/t5x-main/t5x/contrib/gpu/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/t5x-main/t5x/contrib/gpu/docker/build.sh b/t5x-main/t5x/contrib/gpu/docker/build.sh deleted file mode 100644 index 441b5a3ab2760de19233e609a3fdb278a502fabd..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/docker/build.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -echo "Ensure you run this script from the top-level directory of the repo" - -CONTAINER="t5x" -if [ $# -eq 1 ] -then - echo $1 - CONTAINER=$1 -else - echo "Usage: bash build " - exit -fi - -# building container here -docker build -t $CONTAINER . -f t5x/contrib/gpu/Dockerfile diff --git a/t5x-main/t5x/contrib/gpu/docker/interactive_pull_and_launch.sh b/t5x-main/t5x/contrib/gpu/docker/interactive_pull_and_launch.sh deleted file mode 100644 index 77a4b54f725317654e13a70a791d56ea2c3ff959..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/docker/interactive_pull_and_launch.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -x - -CONTAINER=${1} -echo $CONTAINER -docker pull $CONTAINER - -DATASET_PATH=${2} - -## !! Uncomment this to add a custom path to workspace dir !!## -## By default `.../T5X/t5x/workspace` is selected -# WORKSPACE_PATH= - -nvidia-docker run -ti --net=host --ipc=host -v ${PWD}:/t5x_home -v ${DATASET_PATH}:/t5x_home/datasets -v ${WORKSPACE_PATH:-${PWD}/workspace}:/t5x_home/workspace --privileged $CONTAINER /bin/bash -set +x diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/__init__.py b/t5x-main/t5x/contrib/gpu/scripts_gpu/__init__.py deleted file mode 100644 index bbb4a3a6ee82a0b91acf6aaa18b62e2d389a9a61..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022-2023 The T5x Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This empty file is needed to be recognized as a package by the setuptools.""" diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/download_the_pile.py b/t5x-main/t5x/contrib/gpu/scripts_gpu/download_the_pile.py deleted file mode 100644 index 2007c4a0b43e800833b43116fdcd8d1eea54bf89..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/download_the_pile.py +++ /dev/null @@ -1,5 +0,0 @@ -import t5x.contrib.gpu.scripts_gpu.seqio_tasks -import tensorflow_datasets as tfds - -# This will download 'ThePile' to TFDS_DATA_DIR (environment variable). -ds = tfds.load('ThePile') diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/dummy_wikipedia_seqio.py b/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/dummy_wikipedia_seqio.py deleted file mode 100644 index da67767928de9a2d7ddc20aceae070e23dee919e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/dummy_wikipedia_seqio.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools - -import seqio -import t5.data -from t5.data import preprocessors - -TaskRegistry = seqio.TaskRegistry - -DEFAULT_OUTPUT_FEATURES = { - "inputs": seqio.Feature( - vocabulary=seqio.SentencePieceVocabulary( - sentencepiece_model_file="gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model", - ), - add_eos=True, - required=False), - "targets": seqio.Feature( - vocabulary=seqio.SentencePieceVocabulary( - sentencepiece_model_file="gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model", - ), - add_eos=True) -} - -# ================================ Wikipedia =================================== -TaskRegistry.add( - "wikipedia_dummy", - source=seqio.TfdsDataSource(tfds_name="wikipedia/20190301.als:1.0.0"), - preprocessors=[ - functools.partial( - preprocessors.rekey, key_map={ - "inputs": None, - "targets": "text" - }), - seqio.preprocessors.tokenize, - seqio.CacheDatasetPlaceholder(), - preprocessors.unsupervised, - seqio.preprocessors.append_eos_after_trim, - ], - output_features=DEFAULT_OUTPUT_FEATURES, - metric_fns=[]) diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/setup.py b/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/setup.py deleted file mode 100644 index c78b07bea2cfbb4ba922b2003d0bdc5887a9f414..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/setup.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python - -# -*- coding: utf-8 -*- - -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from distutils.core import setup - -setup(name='dummy_wikipedia_seqio', - version='0.0.1', - description='Dummy Wikipedia Seqio Task', - author='The T5X Authors', - install_requires=['t5'], - ) diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/small_pretrain_dummy_wikipedia.gin b/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/small_pretrain_dummy_wikipedia.gin deleted file mode 100644 index 8ece3c641b4856860a39d206fb98c0ab9356341c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config/small_pretrain_dummy_wikipedia.gin +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __gin__ import dynamic_registration - -from t5x import partitioning - -from t5x.examples.t5 import network - -include "t5x/contrib/gpu/t5/t5_1_1/small.gin" -include "t5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin" -include "t5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin" - -# Register necessary SeqIO Tasks/Mixtures. -import t5.data.mixtures -# Register Dummy Wikipedia Seqio Task (needed for benchmarking) -import dummy_wikipedia_seqio - -MIXTURE_OR_TASK_NAME = "wikipedia_dummy" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 128} -TRAIN_STEPS = 100 -DROPOUT_RATE = 0.1 -BATCH_SIZE = 256 -USE_CACHED_TASKS=False - -partitioning.PjitPartitioner: - num_partitions=1 diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub b/t5x-main/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub deleted file mode 100644 index 966ea69287a1a4a9e07cf867e4bdba5f0778bf7a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub +++ /dev/null @@ -1,88 +0,0 @@ -#!/bin/bash -#SBATCH -A example # slurm account -#SBATCH -p partition # slurm partition name -#SBATCH -N 1 # number of nodes -#SBATCH -t 04:00:00 # wall time -#SBATCH -J "t5x:train" # slurm job name -#SBATCH --exclusive # exclusive node access -#SBATCH --mem=0 # all mem avail -#SBATCH --mail-type=FAIL # only send email on failure -#SBATCH --overcommit -#SBATCH --dependency=singleton # tells slurm to run only one job with the same job name at a time -set -x - -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# File system and volume glue code -#------------------------------------------------------------------------------- -# << CHANGE ! >> -SLURM_ACCOUNT='example' -USERID='exampleperson' - -# << CHANGE ! >> -CONTAINER="" # Add link to your built container - -# << CHANGE ! >> -BASE_T5X_DIR="...../t5x_git" # path to your clone of the repo -BASE_TFDS_DATA_DIR="" # path to tfds data directory -BASE_T5X_WORKSPACE_DIR="${BASE_T5X_DIR}/workspace" # path to where outputs will be dumped - -# Default env variables for paths required by t5x training scripts -TFDS_DATA_DIR=/t5x_home/datasets/ -T5X_DIR=/t5x_home/ -T5X_WORKSPACE_DIR=/t5x_home/workspace - -# Add the T5x/JAX specific mounts -MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR" - -# Add T5x/JAX specific exports -EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR}" -#------------------------------------------------------------------------------- - -# Command line arguments needed by the underlying scripts -TASK=$1 # mnli2 or squad1, add others with corresponding gin files -T5_SIZE=$2 # small, base, large, xl, xxl -PREC="$3" # bfloat16, float32 -GPUS_PER_NODE=$4 # usually 8 -BSIZE_PER_GPU=$5 # local batch size/gpu -MODEL_DIR_LOCAL=$6 # directory to save checkpoints and config dump to -NUM_MICROBATCHES=$7 # number of gradient accumulation steps - -NUM_GPUS=$(( GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) - -# redirect both stdout and stderr in the same file for ease of analysis -OUTDIR="outputs/multinode/${TASK}_t5_${T5_SIZE}-prec_${PREC}-nodes_${SLURM_JOB_NUM_NODES}-gpus_${NUM_GPUS}-bs_${BSIZE_PER_GPU}-sl_${SL}" - -OUTFILE="${BASE_T5X_WORKSPACE_DIR}/${OUTDIR}/output-%j-%n.txt" - -LOGDIR="${T5X_WORKSPACE_DIR}/${OUTDIR}" - -# << CHANGE ! >> -# You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. -# && bash <>/bind.sh --cpu=exclusive --ib=single -- \ -read -r -d '' cmd <> -SLURM_ACCOUNT='example' -USERID='exampleperson' - -# << CHANGE ! >> -CONTAINER="" # Add link to your built container - -# << CHANGE ! >> -BASE_T5X_DIR="...../t5x_git" # path to your clone of the repo -BASE_TFDS_DATA_DIR="" # path to tfds data directory -BASE_T5X_WORKSPACE_DIR="${BASE_T5X_DIR}/workspace" # path to where outputs will be dumped - -# Default env variables for paths required by t5x training scripts -TFDS_DATA_DIR=/t5x_home/datasets/ -T5X_DIR=/t5x_home/ -T5X_WORKSPACE_DIR=/t5x_home/workspace - -# Add the T5x/JAX specific mounts -MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR" - -# Add T5x/JAX specific exports -EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR}" -#------------------------------------------------------------------------------- - -# Command line arguments needed by the underlying scripts -T5_SIZE=$1 # small, base, large, xl, xxl -PREC="$2" # bfloat16, float32 -GPUS_PER_NODE=$3 # usually 8 -BSIZE_PER_GPU=$4 # local batch size/gpu -MODEL_DIR_LOCAL=$5 # directory to save checkpoints and config dump to -NUM_MICROBATCHES=$6 # number of gradient accumulation steps -MP=$7 # tensor parallel count - -NUM_GPUS=$(( GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) - -# << CHANGE ! >> -# You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. -# && bash <>/bind.sh --cpu=exclusive --ib=single -- \ -read -r -d '' cmd < \ - ${LOG_DIR}/ft_${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}.log diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh b/t5x-main/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh deleted file mode 100644 index c82322a2ced343252b0bda53b0f4e9a59fa3a0c5..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +++ /dev/null @@ -1,53 +0,0 @@ -#! /bin/bash -# A script for single-node pile pretraining - -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -x - -TFDS_DATA_DIR="/t5x_home/datasets/" # << CHANGE !>> -T5X_DIR=${PWD} - -# Arguments -T5_SIZE=$1 # Model size (small, base, large) -PREC="$2" # Precision (float32, float16, bfloat16) -NUM_GPUS=$3 # Number of GPUs (1, 2, 4, 8) -BSIZE_PER_GPU=$4 # Batch size per GPU (varies with model size) -LOG_DIR=$5 # Output log directory -MODEL_DIR_LOCAL=${6:-"model_dir"} -MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL} -NUM_MICROBATCHES=${7:-0} - -echo $MODEL_DIR - -echo "Please make sure ${NUM_GPUS} is the number of visible CUDA devices you have" - -# Setting XLA flags -export XLA_FLAGS="--xla_gpu_simplify_all_fp_conversions --xla_gpu_all_reduce_combine_threshold_bytes=136314880 ${XLA_FLAGS}" - -# Global batch size -BSIZE=$(( NUM_GPUS * BSIZE_PER_GPU )) - -rm -rf "${MODEL_DIR}/*" -python3 -u ${T5X_DIR}/t5x/train.py \ - --gin_file="t5x/contrib/gpu/t5/t5_1_1/examples/${T5_SIZE}_pile_pretrain.gin" \ - --gin.MODEL_DIR=\"${MODEL_DIR}\" \ - --gin.network.T5Config.dtype=\"${PREC}\" \ - --tfds_data_dir=${TFDS_DATA_DIR} \ - --gin.train/utils.DatasetConfig.batch_size=${BSIZE} \ - --gin.trainer.Trainer.num_microbatches=${NUM_MICROBATCHES} \ - --gin.train_eval/utils.DatasetConfig.batch_size=${BSIZE} \ - --gin.infer_eval/utils.DatasetConfig.batch_size=${BSIZE} &> \ - ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}.log diff --git a/t5x-main/t5x/contrib/gpu/scripts_gpu/tfds_pile.py b/t5x-main/t5x/contrib/gpu/scripts_gpu/tfds_pile.py deleted file mode 100644 index 77ef0d96a9c2203d8ece052b72ee4b277b278bf6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/scripts_gpu/tfds_pile.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -From: https://github.com/EleutherAI/the-pile -the_pile dataset -""" - -import tensorflow_datasets as tfds -import tensorflow as tf -import io -import zstandard -import jsonlines -import os -import time -from itertools import chain -""" -Tips for Colab - Change _PILE_SPLITS below to increments of 8 to allow downloading and storing in GCS -After every 8 parts, tfds will flush the tempfiles from local and it will be cached on GCS, allowing reuse -preventing th need to redownload again. Example below - -_download: Skipping download of http://eaidata.bmk.sh/data/pile/train/26.jsonl.zst: File cached in gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_26.jsonlCue2aNl9cxodxAvl9vIacuexGWYSoJAt4Rpcy19pqds.zst -_download: Skipping download of http://eaidata.bmk.sh/data/pile/train/27.jsonl.zst: File cached in gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_27.jsonlt8W_PLYeC4bZeaNMqMhe0-lhS3ijPL7RjvILWsMZlhQ.zst -_download: Downloading http://eaidata.bmk.sh/data/pile/train/28.jsonl.zst into gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_28.jsonl7Fj9nvI6std-e0H2ScxDKMpTWEC8iJMI8OT2vxLw2A4.zst.tmp.576c9ac11d30419b8ea8f30a5157ee53... -_download: Downloading http://eaidata.bmk.sh/data/pile/train/29.jsonl.zst into gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_29.jsonl1syFpl-ESnwk__9_6Xrj_OO5mRxpmaxQG7bZ_5d2sZc.zst.tmp.2f7f6afb86d74e988dcdb71d59b0d3f2... - - -Use tfds.disable_progress_bar() to prevent javascript issues -This uses pysimdjson for faster parsing of json. The entire dataset should be completed in around 12 hours on Colab. - -""" - -_USAGE_EXAMPLE = """ -This can be run in a script or in a notebook. - -_GCS_BUCKET = 'gs://your_gcs_bucket/path' - -import os -os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/adc.json' # if building to store in GCS -os.environ['TFDS_DATA_DIR'] = _GCS_BUCKET - -import tensorflow_datasets as tfds -from the_pile import tfds_pile -from transformers import GPT2TokenizerFast - -tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') -tokenizer.add_special_tokens({'pad_token': '<|padding|>'}) - -def simple_tokenization(item): - return tokenizer.encode(item['text'], return_tensors='tf') - -tfds.disable_progress_bar() # optional - will help with colab since tqdm breaks often - -ds = tfds.load(name="ThePile", try_gcs=True) - -# Have not tested below -ds.map(simple_tokenization, num_parallel_calls=tf.data.experimental.AUTOTUNE) -# or -ds.map(lambda item: simple_tokenization(item), num_parallel_calls=tf.data.experimental.AUTOTUNE) - -""" - -try: - import simdjson as json -except ImportError: - print('Installing simdjson library') - os.system('pip install -q pysimdjson') - import simdjson as json - parser = json.Parser() - -parser = json.Parser() -_DESCRIPTION = """ -The Pile is a large, diverse, open source language modelling data set -that consists of many smaller datasets combined together. -The objective is to obtain text from as many modalities as possible to -ensure that models trained using The Pile will have much broader generalization abilities. -We are currently developing Version 1, with an ultimate goal of 1 TiB of English text. -After the completion of Version 1, our next goal is a fully-multilingual, 10TiB text dataset. -""" - -_CITATION = """ -""" -_DATASET_MODES = ["lm"] - -_PILE_URL = 'https://the-eye.eu/public/AI/pile/train/{}.jsonl.zst' -_PILE_SPLITS = 30 - -_URLS = { - 'the_pile': { - 'train': [ - _PILE_URL.format(str(i).zfill(2)) for i in range(_PILE_SPLITS) - ], - 'test': 'https://the-eye.eu/public/AI/pile/test.jsonl.zst', - 'validation': 'https://the-eye.eu/public/AI/pile/val.jsonl.zst', - } -} - -_VERSION = tfds.core.Version('1.0.0') -_RELEASE_NOTES = { - '1.0.0': 'Initial release.', -} - -_NAME = 'the_pile' -_FILE_FORMAT = 'jsonlines' - - -def json_parser(x): - global parser - try: - line = parser.parse(x).as_dict() - return line - except ValueError: - return x - - -class PileReader: - - def __init__(self, filenames, para_joiner='\n\n'): - if not isinstance(filenames, list): - filenames = [filenames] - self.filenames = filenames - self.para_joiner = para_joiner - - def _read_fn(self, filename): - print(filename) - with tf.io.gfile.GFile(filename, 'rb+') as f: - cctx = zstandard.ZstdDecompressor() - reader_stream = io.BufferedReader(cctx.stream_reader(f)) - reader = jsonlines.Reader(reader_stream, loads=json_parser) - print('reader made') - for item in reader: - result = dict() - if isinstance(item, str): - result['text'] = item - else: - text = item['text'] - if isinstance(text, list): - text = self.para_joiner.join(text) - result['text'] = text - - yield result - - def __iter__(self): - print(self.filenames) - #for item in chain.from_iterable([self._read_fn(filename) for filename in self.filenames]): - # return item - #for filename in self.filenames: - # return self._read_fn(filename) - return chain.from_iterable( - [self._read_fn(filename) for filename in self.filenames]) - - -class ThePileConfig(tfds.core.BuilderConfig): - - def __init__(self, *, mode=None, **kwargs): - super(ThePileConfig, self).__init__(name=mode, - description="The Pile dataset", - **kwargs) - - -class ThePile(tfds.core.GeneratorBasedBuilder): - BUILDER_CONFIGS = [ - ThePileConfig(version=_VERSION, mode=mode) for mode in _DATASET_MODES - ] - - def _info(self) -> tfds.core.DatasetInfo: - return tfds.core.DatasetInfo( - builder=self, - description=_DESCRIPTION, - features=tfds.features.FeaturesDict({'text': tfds.features.Text()}), - supervised_keys=("text", "text"), - homepage='https://github.com/EleutherAI/The-Pile', - citation=_CITATION, - ) - - def _split_generators(self, dl_manager: tfds.download.DownloadManager): - dl_manager.verify_ssl = False - dl_paths = dl_manager.download(_URLS['the_pile']) - print(dl_paths) - return { - 'train': self._generate_examples(dl_paths['train']), - 'validation': self._generate_examples(dl_paths['validation']), - 'test': self._generate_examples(dl_paths['test']), - } - - def _generate_examples(self, paths): - pipeline = PileReader(paths) - #print('pipeline', pipeline) - for x, result in enumerate(pipeline): - if result: - idx = f'{x}_the_pile' - yield idx, {'text': result['text']} diff --git a/t5x-main/t5x/contrib/gpu/t5/README.md b/t5x-main/t5x/contrib/gpu/t5/README.md deleted file mode 100644 index bcabd31410b413909d05e5f0d8bd5f26d020e29c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/README.md +++ /dev/null @@ -1,6 +0,0 @@ -This directory contains model implementations for the T5-variants (T5.1.1, -T5.1.0, mT5, ByT5). All variants share the neural network implementation in -`network.py`, which has a minimal set of configurables in `TransformerConfig`. - -Refer to the [main -README](https://github.com/google-research/t5x/blob/main/README.md) for the example usages. diff --git a/t5x-main/t5x/contrib/gpu/t5/__init__.py b/t5x-main/t5x/contrib/gpu/t5/__init__.py deleted file mode 100644 index da022c16301721a096a208e8bdb2a71bb87f9788..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/contrib/gpu/t5/byt5/__init__.py b/t5x-main/t5x/contrib/gpu/t5/byt5/__init__.py deleted file mode 100644 index da022c16301721a096a208e8bdb2a71bb87f9788..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/byt5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/contrib/gpu/t5/byt5/base.gin b/t5x-main/t5x/contrib/gpu/t5/byt5/base.gin deleted file mode 100644 index f58325490509ebb922e3a8017f9873cb6554b67d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/byt5/base.gin +++ /dev/null @@ -1,54 +0,0 @@ -# ByT5 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.contrib.gpu.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.ByteVocabulary() - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 384 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 1536 - num_heads = 12 - num_encoder_layers = 18 - num_decoder_layers = 6 - head_dim = 64 - mlp_dim = 3968 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False diff --git a/t5x-main/t5x/contrib/gpu/t5/byt5/large.gin b/t5x-main/t5x/contrib/gpu/t5/byt5/large.gin deleted file mode 100644 index 4be8bab7c49c87f2233180b85e37ff47aa0cc67e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/byt5/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 Large model. - -include 't5x/contrib/gpu/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1536 - num_heads = 16 - num_encoder_layers = 36 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 3840 diff --git a/t5x-main/t5x/contrib/gpu/t5/byt5/small.gin b/t5x-main/t5x/contrib/gpu/t5/byt5/small.gin deleted file mode 100644 index 52663bcfd164c40aab037a09f3f1098e837018cf..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/byt5/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 Small model. - -include 't5x/contrib/gpu/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1472 - num_heads = 6 - num_encoder_layers = 12 - num_decoder_layers = 4 - head_dim = 64 - mlp_dim = 3584 diff --git a/t5x-main/t5x/contrib/gpu/t5/byt5/tiny.gin b/t5x-main/t5x/contrib/gpu/t5/byt5/tiny.gin deleted file mode 100644 index 04ad2c1ec22c0da73cde09cf9ea691e3bdf2fd63..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/byt5/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/contrib/gpu/t5/byt5/xl.gin b/t5x-main/t5x/contrib/gpu/t5/byt5/xl.gin deleted file mode 100644 index af07c4f941b6d0c29e7de0c3c2db8a3a900c07b9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/byt5/xl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 XL model. - -include 't5x/contrib/gpu/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2560 - num_heads = 32 - num_encoder_layers = 36 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 6720 diff --git a/t5x-main/t5x/contrib/gpu/t5/byt5/xxl.gin b/t5x-main/t5x/contrib/gpu/t5/byt5/xxl.gin deleted file mode 100644 index 49c7a89fa050d572a92935b8fd92370e1812078c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/byt5/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 XXL model. - -include 't5x/contrib/gpu/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4672 - num_heads = 64 - num_encoder_layers = 36 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 12352 diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/__init__.py b/t5x-main/t5x/contrib/gpu/t5/configs/runs/__init__.py deleted file mode 100644 index da022c16301721a096a208e8bdb2a71bb87f9788..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/eval.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/eval.gin deleted file mode 100644 index 278b92e7ca51d4a12785b4befb11d85aea400e2c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/eval.gin +++ /dev/null @@ -1,68 +0,0 @@ -# Defaults for eval.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to evaluate on -# - CHECKPOINT_PATH: The model checkpoint to evaluate -# - EVAL_OUTPUT_DIR: The dir to write results to. -# -# -# Commonly overridden options: -# -# - DatasetConfig.split -# - DatasetConfig.batch_size -# - DatasetConfig.use_cached -# - RestoreCheckpointConfig.mode -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as eval_script -import seqio -from t5x import partitioning -from t5x import utils - - -# Must be overridden -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -EVAL_OUTPUT_DIR = %gin.REQUIRED -TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -eval_script.evaluate: - model = %MODEL # imported from separate gin file - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - output_dir = %EVAL_OUTPUT_DIR - inference_evaluator_cls = @seqio.Evaluator - -partitioning.PjitPartitioner: - num_partitions = 1 - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the dataset. - use_memory_cache = True - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'test' - batch_size = 32 - shuffle = False - seed = 42 - use_cached = False - pack = False - use_custom_packing_ops = False - module = %MIXTURE_OR_TASK_MODULE - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/export.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/export.gin deleted file mode 100644 index 53177ced24b1cdc6da3afff22a3a4ec64572aecb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/export.gin +++ /dev/null @@ -1,93 +0,0 @@ -# Defaults for single_core_export.py. -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features -# to. -# - CHECKPOINT_PATH: The model checkpoint to use for inference -# - INFER_OUTPUT_DIR: The dir to write results to. When launching using -# XManager, this is set automatically. -# -# Commonly overridden options: -# -# warmup_examples: Optional[List[str]] = None -# jit_compile: bool = False - -from __gin__ import dynamic_registration - -import seqio - -from t5x import checkpoints -from t5x import models -from t5x import partitioning -from t5x import utils -from t5x import export_lib - -# Must be overridden -OUTPUT_FEATURES = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -MODEL_OUTPUT_DIR = %gin.REQUIRED -MODEL_NAME = %gin.REQUIRED -BATCH_SIZE = None -BEAM_SIZE = 1 - -OUTPUT_FEATURES = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()} - -# Plumbing to extract the vocabulary directly from MODEL. This is needed to -# tokenize the features from the saved model inputs we aren't provided with -# vocabularies via a Task. -inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary() -models.get_input_vocabulary.model = %MODEL # imported from separate gin file -outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary() -models.get_output_vocabulary.model = %MODEL # imported from separate gin file - - -# Typical for inference settings: -ACTIVATION_DTYPE = 'bfloat16' - -export_lib.save: - model = %MODEL # imported from separate gin file - inference_mode = 'predict' - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - exportable_module_cls = @export_lib.ExportableModule - create_preprocessor_fn = @export_lib.create_preprocessor - create_postprocessor_fn = @export_lib.create_postprocessor - write_warmup_example_fn = @export_lib.write_warmup_examples - partitioner = @partitioning.PjitPartitioner() - output_features = %OUTPUT_FEATURES - task_feature_lengths = %TASK_FEATURE_LENGTHS - output_dir = %MODEL_OUTPUT_DIR - model_name = %MODEL_NAME - batch_size = %BATCH_SIZE - native_lowering = False - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' - dtype = 'bfloat16' - checkpointer_cls = @checkpoints.Checkpointer - # TODO(b/234480674): GDA disabled due to incompatibility with export. - use_gda = False - -export_lib.create_preprocessor: - output_features = %OUTPUT_FEATURES - task_feature_lengths = %TASK_FEATURE_LENGTHS - -export_lib.create_postprocessor: - output_feature_names = None - -export_lib.ExportableModule: - jit_compile = True - use_batch_function = False - -partitioning.PjitPartitioner: - num_partitions = 1 - params_on_devices = True - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -models.EncoderDecoderModel.predict_batch_with_aux: - num_decodes = %BEAM_SIZE - return_all_decodes = True diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/export_seqio.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/export_seqio.gin deleted file mode 100644 index 04fcc70c4641ea8ab5093dd863379fac2bc43889..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/export_seqio.gin +++ /dev/null @@ -1,20 +0,0 @@ -from __gin__ import dynamic_registration - -from t5x import export_lib -from t5x import partitioning - -include 't5x/contrib/gpu/t5/configs/runs/export.gin' - -MIXTURE_OR_TASK_NAME = %gin.REQUIRED - -export_lib.save: - create_preprocessor_fn = @export_lib.create_preprocessor_from_task - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - output_features = None - -export_lib.create_preprocessor_from_task: - model = %MODEL - task_feature_lengths = %TASK_FEATURE_LENGTHS - task_name = %MIXTURE_OR_TASK_NAME - serialized_examples = True - run_precache = False diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune.gin deleted file mode 100644 index a76d80957741d4f9b6c710101ed29bf2bf430423..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune.gin +++ /dev/null @@ -1,150 +0,0 @@ -# Defaults for finetuning with train.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS # includes pretrain steps -# - MODEL_DIR # automatically set when using xm_launch -# - INITIAL_CHECKPOINT_PATH -# -# When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. -# -# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt -# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. -# -# Commonly overridden options: -# - DROPOUT_RATE -# - BATCH_SIZE -# - PjitPartitioner.num_partitions -# - Trainer.num_microbatches -# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess -# on the fly. Most common tasks are cached, hence this is set to True by -# default. - -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -# Must be overridden -MODEL_DIR = %gin.REQUIRED -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -MIXTURE_OR_TASK_MODULE = %gin.REQUIRED -TRAIN_STEPS = %gin.REQUIRED -INITIAL_CHECKPOINT_PATH = %gin.REQUIRED - -# Commonly overridden -DROPOUT_RATE = 0.1 -USE_CACHED_TASKS = True -BATCH_SIZE = 128 - -# Sometimes overridden -EVAL_STEPS = 20 -EVAL_PERIOD = 1000 - -# Convenience overrides. -EVALUATOR_USE_MEMORY_CACHE = True -EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. -JSON_WRITE_N_RESULTS = None # Write all inferences. -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -train_script.train: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - checkpoint_cfg = @utils.CheckpointConfig() - partitioner = @partitioning.PjitPartitioner() - trainer_cls = @trainer.Trainer - total_steps = %TRAIN_STEPS - eval_steps = %EVAL_STEPS - eval_period = %EVAL_PERIOD - random_seed = %RANDOM_SEED - use_hardware_rng = %USE_HARDWARE_RNG - summarize_config_fn = @gin_utils.summarize_gin_config - inference_evaluator_cls = @seqio.Evaluator - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = %EVALUATOR_NUM_EXAMPLES - use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE - -seqio.JSONLogger: - write_n_results = %JSON_WRITE_N_RESULTS - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = True - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -train_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # compute max - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = False - module = %MIXTURE_OR_TASK_MODULE - -utils.CheckpointConfig: - restore = @utils.RestoreCheckpointConfig() - save = @utils.SaveCheckpointConfig() -utils.RestoreCheckpointConfig: - path = %INITIAL_CHECKPOINT_PATH - mode = 'specific' - dtype = 'float32' -utils.SaveCheckpointConfig: - period = 5000 - dtype = 'float32' - keep = None # keep all checkpoints - save_dataset = False # don't checkpoint dataset state - -trainer.Trainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() -utils.create_learning_rate_scheduler: - factors = 'constant' - base_learning_rate = 0.001 - warmup_steps = 1000 diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin deleted file mode 100644 index 7cebbb6d9141b06877dfa7c3ae4cafef0f374316..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin +++ /dev/null @@ -1,151 +0,0 @@ -# Defaults for finetuning with train.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS # includes pretrain steps -# - MODEL_DIR # automatically set when using xm_launch -# - INITIAL_CHECKPOINT_PATH -# -# When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. -# -# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt -# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. -# -# Commonly overridden options: -# - DROPOUT_RATE -# - BATCH_SIZE -# - PjitPartitioner.num_partitions -# - Trainer.num_microbatches -# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess -# on the fly. Most common tasks are cached, hence this is set to True by -# default. - -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -# Must be overridden -MODEL_DIR = %gin.REQUIRED -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -MIXTURE_OR_TASK_MODULE = %gin.REQUIRED -TRAIN_STEPS = %gin.REQUIRED -INITIAL_CHECKPOINT_PATH = %gin.REQUIRED - -# Commonly overridden -DROPOUT_RATE = 0.1 -USE_CACHED_TASKS = False -BATCH_SIZE = 128 - -# Sometimes overridden -EVAL_STEPS = 20 - -# Convenience overrides. -EVALUATOR_USE_MEMORY_CACHE = True -EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. -JSON_WRITE_N_RESULTS = None # Write all inferences. -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -train_script.train: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - checkpoint_cfg = @utils.CheckpointConfig() - partitioner = @partitioning.PjitPartitioner() - trainer_cls = @trainer.Trainer - total_steps = %TRAIN_STEPS - eval_steps = %EVAL_STEPS - eval_period = 1000 - random_seed = %RANDOM_SEED - use_hardware_rng = %USE_HARDWARE_RNG - summarize_config_fn = @gin_utils.summarize_gin_config - inference_evaluator_cls = @seqio.Evaluator - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = %EVALUATOR_NUM_EXAMPLES - use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE - -seqio.JSONLogger: - write_n_results = %JSON_WRITE_N_RESULTS - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = True - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -train_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation_mismatched' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # compute max - split = 'validation_matched' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = False - module = %MIXTURE_OR_TASK_MODULE - -utils.CheckpointConfig: - restore = @utils.RestoreCheckpointConfig() - save = @utils.SaveCheckpointConfig() -utils.RestoreCheckpointConfig: - path = %INITIAL_CHECKPOINT_PATH - #mode = 'specific' - dtype = 'float32' -utils.SaveCheckpointConfig: - period = 5000 - dtype = 'float32' - keep = None # keep all checkpoints - save_dataset = False # don't checkpoint dataset state - -trainer.Trainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() -utils.create_learning_rate_scheduler: - factors = 'linear_decay' - base_learning_rate = 0.001 - warmup_steps = 1001000 - min_learning_rate=0.00001 - decay_factor = 1.67e-5 diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin deleted file mode 100644 index 4ea952c910827a2432cd64ff0f313093bd58b711..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin +++ /dev/null @@ -1,151 +0,0 @@ -# Defaults for finetuning with train.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS # includes pretrain steps -# - MODEL_DIR # automatically set when using xm_launch -# - INITIAL_CHECKPOINT_PATH -# -# When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. -# -# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt -# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. -# -# Commonly overridden options: -# - DROPOUT_RATE -# - BATCH_SIZE -# - PjitPartitioner.num_partitions -# - Trainer.num_microbatches -# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess -# on the fly. Most common tasks are cached, hence this is set to True by -# default. - -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -# Must be overridden -MODEL_DIR = %gin.REQUIRED -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -MIXTURE_OR_TASK_MODULE = %gin.REQUIRED -TRAIN_STEPS = %gin.REQUIRED -INITIAL_CHECKPOINT_PATH = %gin.REQUIRED - -# Commonly overridden -DROPOUT_RATE = 0.1 -USE_CACHED_TASKS = False -BATCH_SIZE = 128 - -# Sometimes overridden -EVAL_STEPS = 20 - -# Convenience overrides. -EVALUATOR_USE_MEMORY_CACHE = True -EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. -JSON_WRITE_N_RESULTS = None # Write all inferences. -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -train_script.train: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - checkpoint_cfg = @utils.CheckpointConfig() - partitioner = @partitioning.PjitPartitioner() - trainer_cls = @trainer.Trainer - total_steps = %TRAIN_STEPS - eval_steps = %EVAL_STEPS - eval_period = 1000 - random_seed = %RANDOM_SEED - use_hardware_rng = %USE_HARDWARE_RNG - summarize_config_fn = @gin_utils.summarize_gin_config - inference_evaluator_cls = @seqio.Evaluator - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = %EVALUATOR_NUM_EXAMPLES - use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE - -seqio.JSONLogger: - write_n_results = %JSON_WRITE_N_RESULTS - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = True - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -train_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # compute max - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = False - module = %MIXTURE_OR_TASK_MODULE - -utils.CheckpointConfig: - restore = @utils.RestoreCheckpointConfig() - save = @utils.SaveCheckpointConfig() -utils.RestoreCheckpointConfig: - path = %INITIAL_CHECKPOINT_PATH - #mode = 'specific' - dtype = 'float32' -utils.SaveCheckpointConfig: - period = 5000 - dtype = 'float32' - keep = None # keep all checkpoints - save_dataset = False # don't checkpoint dataset state - -trainer.Trainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() -utils.create_learning_rate_scheduler: - factors = 'linear_decay' - base_learning_rate = 0.001 - warmup_steps = 1001000 - min_learning_rate=0.00001 - decay_factor = 3.75e-5 diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/infer.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/infer.gin deleted file mode 100644 index 0918d2f4843d698cf27787c62f7e09cf81c1e835..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/infer.gin +++ /dev/null @@ -1,71 +0,0 @@ -# Defaults for infer.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference -# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features -# to. -# - CHECKPOINT_PATH: The model checkpoint to use for inference -# - INFER_OUTPUT_DIR: The dir to write results to. -# -# -# Commonly overridden options: -# -# - infer.mode -# - infer.checkpoint_period -# - infer.shard_id -# - infer.num_shards -# - DatasetConfig.split -# - DatasetConfig.batch_size -# - DatasetConfig.use_cached -# - RestoreCheckpointConfig.is_tensorflow -# - RestoreCheckpointConfig.mode -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as infer_script -from t5x import partitioning -from t5x import utils - -# Must be overridden -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -INFER_OUTPUT_DIR = %gin.REQUIRED - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None - -infer_script.infer: - mode = 'predict' - model = %MODEL # imported from separate gin file - output_dir = %INFER_OUTPUT_DIR - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - checkpoint_period = 100 - shard_id = 0 - num_shards = 1 - -partitioning.PjitPartitioner: - num_partitions = 1 - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - module = %MIXTURE_OR_TASK_MODULE - task_feature_lengths = %TASK_FEATURE_LENGTHS - use_cached = False - split = 'test' - batch_size = 32 - shuffle = False - seed = 0 - pack = False - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' - dtype = 'bfloat16' diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/infer_from_tfexample_file.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/infer_from_tfexample_file.gin deleted file mode 100644 index 5d62b27555ecfef3cd801098fe640ac09eff744c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/infer_from_tfexample_file.gin +++ /dev/null @@ -1,90 +0,0 @@ -# Defaults for infer.py if using a TFExample file as input. -# -# -# The features from each TFExample are tokenized using the model's vocabulary. -# By default, the inputs feature is assumed to be keyed as 'inputs', but this -# can be overridden with `create_task_from_tfexample_file.inputs_key`. -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from. -# - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently -# supported: 'tfrecord', 'recordio', 'sstable'. -# - FEATURE_LENGTHS: The maximum length per feature in the TF Examples. -# - CHECKPOINT_PATH: The model checkpoint to use for inference -# - INFER_OUTPUT_DIR: The dir to write results to. -# -# -# Commonly overridden options: -# -# - infer.mode -# - infer.checkpoint_period -# - infer.shard_id -# - infer.num_shards -# - create_task_from_tfexample_file.inputs_key -# - create_task_from_tfexample_file.targets_key -# - DatasetConfig.split -# - DatasetConfig.batch_size -# - RestoreCheckpointConfig.mode -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as infer_script -import seqio -from t5x import models -from t5x import partitioning -from t5x import utils - -# Must be overridden -TF_EXAMPLE_FILE_PATHS = %gin.REQUIRED -TF_EXAMPLE_FILE_TYPE = %gin.REQUIRED -FEATURE_LENGTHS = %gin.REQUIRED -CHECKPOINT_PATH = %gin.REQUIRED -INFER_OUTPUT_DIR = %gin.REQUIRED - -infer_script.infer: - mode = 'predict' - model = %MODEL # imported from separate gin file - output_dir = %INFER_OUTPUT_DIR - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - checkpoint_period = 100 - shard_id = 0 - num_shards = 1 - -partitioning.PjitPartitioner: - num_partitions = 1 - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -utils.DatasetConfig: - mixture_or_task_name = @infer_script.create_task_from_tfexample_file() - task_feature_lengths = %FEATURE_LENGTHS - split = 'infer' - batch_size = 32 - shuffle = False - seed = 0 - pack = False - -infer_script.create_task_from_tfexample_file: - paths = %TF_EXAMPLE_FILE_PATHS - file_type = %TF_EXAMPLE_FILE_TYPE - inputs_key = 'inputs' - targets_key = None - features = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()} - -# Plumbing to extract the vocabulary directly from MODEL. This is needed to -# tokenize the features from the TFExample we aren't provided with vocabularies -# via a Task. -inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary() -models.get_input_vocabulary.model = %MODEL -outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary() -models.get_output_vocabulary.model = %MODEL - -utils.RestoreCheckpointConfig: - mode = 'specific' - path = %CHECKPOINT_PATH - dtype = 'bfloat16' - diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/precompile.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/precompile.gin deleted file mode 100644 index 787d7d9a0f0a107bcd59ba9e5d83442fd042182b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/precompile.gin +++ /dev/null @@ -1,59 +0,0 @@ -# Defaults for precompile mode in main.py. -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS -# - MODEL_DIR: # automatically set when using xm_launch -# -# Commonly overridden options: -# -# - USE_CACHED_TASKS -# - BATCH_SIZE -# - PjitPartitioner.num_partitions -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -MODEL_DIR = %gin.REQUIRED -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED - - -# Commonly overridden -USE_CACHED_TASKS = True -BATCH_SIZE = 128 - -# None always uses faster, hardware RNG -RANDOM_SEED = None - -train_script.precompile: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - random_seed = %RANDOM_SEED - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - backend = "tpu" - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = True - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/pretrain.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/pretrain.gin deleted file mode 100644 index de1286467d277237dd06102c2b07cbdd6859d4df..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/pretrain.gin +++ /dev/null @@ -1,108 +0,0 @@ -# Defaults for pretraining with train.py. -# -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS -# - MODEL_DIR: # automatically set when using xm_launch -# -# Commonly overridden options: -# -# - train/DatasetConfig.batch_size -# - train_eval/DatasetConfig.batch_size -# - PjitPartitioner.num_partitions -# - Trainer.num_microbatches -# - DROPOUT_RATE -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import gin_utils -from t5x import partitioning -from t5x import utils -from t5x import trainer - -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = %gin.REQUIRED -TRAIN_STEPS = %gin.REQUIRED -MODEL_DIR = %gin.REQUIRED -BATCH_SIZE = 128 -USE_CACHED_TASKS = True - -# DEPRECATED: Import the this module in your gin file. -MIXTURE_OR_TASK_MODULE = None -SHUFFLE_TRAIN_EXAMPLES = True - -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None - -# Can be overridden with `train.*`.` -train_script.train: - model = %MODEL # imported from separate gin file - model_dir = %MODEL_DIR - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = None - checkpoint_cfg = @utils.CheckpointConfig() - partitioner = @partitioning.PjitPartitioner() - trainer_cls = @trainer.Trainer - total_steps = %TRAIN_STEPS - eval_steps = 20 - eval_period = 1000 - random_seed = %RANDOM_SEED - use_hardware_rng = %USE_HARDWARE_RNG - summarize_config_fn = @gin_utils.summarize_gin_config - -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - logical_axis_rules = @partitioning.standard_logical_axis_rules() - -train/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'train' - batch_size = %BATCH_SIZE - shuffle = %SHUFFLE_TRAIN_EXAMPLES - seed = None # use a new seed each run/restart - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -train_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = %BATCH_SIZE - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = True - module = %MIXTURE_OR_TASK_MODULE - -utils.CheckpointConfig: - restore = @utils.RestoreCheckpointConfig() - save = @utils.SaveCheckpointConfig() -utils.RestoreCheckpointConfig: - path = [] # initialize from scratch -utils.SaveCheckpointConfig: - period = 1000 - dtype = 'float32' - keep = None # keep all checkpoints - save_dataset = False # don't checkpoint dataset state - -trainer.Trainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() - -utils.create_learning_rate_scheduler: - factors = 'constant * rsqrt_decay' - base_learning_rate = 1.0 - warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. diff --git a/t5x-main/t5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin b/t5x-main/t5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin deleted file mode 100644 index 4031eae06fcd80d3e2965edcbd3aeb4b68e611fb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin +++ /dev/null @@ -1,17 +0,0 @@ -include 't5x/contrib/gpu/t5/configs/runs/pretrain.gin' - -USE_CACHED_TASKS = False - -utils.SaveCheckpointConfig: - period = 6000 - dtype = 'float32' - keep = 2 # keep 2 checkpoints - save_dataset = True # checkpoint dataset state - -# This scheduler is made with adam in mind. Use the scheduler from pretrain.gin if using adafactor -utils.create_learning_rate_scheduler: - factors = 'linear_decay' - base_learning_rate = 0.0001 - warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. - min_learning_rate = 0.00001 - decay_factor = 9.0909e-7 diff --git a/t5x-main/t5x/contrib/gpu/t5/layers.py b/t5x-main/t5x/contrib/gpu/t5/layers.py deleted file mode 100644 index 383fdab8e43e1b167f81da24377b765c76f4c90c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/layers.py +++ /dev/null @@ -1,869 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dense attention classes and mask/weighting functions.""" - -# pylint: disable=attribute-defined-outside-init,g-bare-generic - -import dataclasses -import functools -import operator -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union - -from flax import linen as nn -from flax.linen import partitioning as nn_partitioning -import jax -from jax import lax -from jax import random -import jax.numpy as jnp -import numpy as np - - -# from flax.linen.partitioning import param_with_axes, with_sharding_constraint -param_with_axes = nn_partitioning.param_with_axes -with_sharding_constraint = nn_partitioning.with_sharding_constraint - - -# Type annotations -Array = jnp.ndarray -DType = jnp.dtype -PRNGKey = jnp.ndarray -Shape = Iterable[int] -Activation = Callable[..., Array] -# Parameter initializers. -Initializer = Callable[[PRNGKey, Shape, DType], Array] - -default_embed_init = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0) - - -def dot_product_attention(query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: DType = jnp.float32, - float32_logits: bool = False): - """Computes dot-product attention given query, key, and value. - - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It calculates the attention weights given - query and key and combines the values using the attention weights. - - Args: - query: queries for calculating attention with shape of `[batch, q_length, - num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of `[batch, kv_length, - num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of `[batch, kv_length, - num_heads, v_depth_per_head]`. - bias: bias for the attention weights. This should be broadcastable to the - shape `[batch, num_heads, q_length, kv_length]` This can be used for - incorporating causal masks, padding masks, proximity bias, etc. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - dtype: the dtype of the computation (default: float32) - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - - Returns: - Output of shape `[batch, length, num_heads, v_depth_per_head]`. - """ - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' - - # Casting logits and softmax computation for float32 for model stability. - if float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) - - # `attn_weights`: [batch, num_heads, q_length, kv_length] - attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) - - # Apply attention bias: masking, dropout, proximity bias, etc. - if bias is not None: - attn_weights = attn_weights + bias.astype(attn_weights.dtype) - - # Normalize the attention weights across `kv_length` dimension. - attn_weights = jax.nn.softmax(attn_weights).astype(dtype) - - # Apply attention dropout. - if not deterministic and dropout_rate > 0.: - keep_prob = 1.0 - dropout_rate - # T5 broadcasts along the "length" dim, but unclear which one that - # corresponds to in positional dimensions here, assuming query dim. - dropout_shape = list(attn_weights.shape) - dropout_shape[-2] = 1 - keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) - keep = jnp.broadcast_to(keep, attn_weights.shape) - multiplier = ( - keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) - attn_weights = attn_weights * multiplier - - # Take the linear combination of `value`. - return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) - - -dynamic_vector_slice_in_dim = jax.vmap( - lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) - - -class MultiHeadDotProductAttention(nn.Module): - """Multi-head dot-product attention. - - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - head_dim: dimension of each head. - dtype: the dtype of the computation. - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - """ - - num_heads: int - head_dim: int - dtype: DType = jnp.float32 - dropout_rate: float = 0. - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal') - float32_logits: bool = False # computes logits in float32 for stability. - scale_attn_logits: bool = False - - @nn.compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - decode: bool = False, - deterministic: bool = False) -> Array: - """Applies multi-head dot product attention on the input data. - - Projects the inputs into multi-headed query, key, and value vectors, - applies dot-product attention and project the results to an output vector. - - There are two modes: decoding and non-decoding (e.g., training). The mode is - determined by `decode` argument. For decoding, this method is called twice, - first to initialize the cache and then for an actual decoding process. The - two calls are differentiated by the presence of 'cached_key' in the variable - dict. In the cache initialization stage, the cache variables are initialized - as zeros and will be filled in the subsequent decoding process. - - In the cache initialization call, `inputs_q` has a shape [batch, length, - q_features] and `inputs_kv`: [batch, length, kv_features]. During the - incremental decoding stage, query, key and value all have the shape [batch, - 1, qkv_features] corresponding to a single step. - - Args: - inputs_q: input queries of shape `[batch, q_length, q_features]`. - inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. - mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. - bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. - decode: Whether to prepare and use an autoregressive cache. - deterministic: Disables dropout if set to True. - - Returns: - output of shape `[batch, length, q_features]`. - """ - projection = functools.partial( - DenseGeneral, - axis=-1, - features=(self.num_heads, self.head_dim), - kernel_axes=('embed', 'joined_kv'), - dtype=self.dtype) - - # NOTE: T5 does not explicitly rescale the attention logits by - # 1/sqrt(depth_kq)! This is folded into the initializers of the - # linear transformations, which is equivalent under Adafactor. - depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - query_init = lambda *args: self.kernel_init(*args) / depth_scaling - - # Project inputs_q to multi-headed q/k/v - # dimensions are then [batch, length, num_heads, head_dim] - query = projection(kernel_init=query_init, name='query')( \ - (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q) - key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) - value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) - - query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) - key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) - value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) - - if decode: - # Detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable('cache', 'cached_key') - # The key and value have dimension [batch, length, num_heads, head_dim], - # but we cache them as [batch, num_heads, head_dim, length] as a TPU - # fusion optimization. This also enables the "scatter via one-hot - # broadcast" trick, which means we do a one-hot broadcast instead of a - # scatter/gather operations, resulting in a 3-4x speedup in practice. - swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) - cached_key = self.variable('cache', 'cached_key', jnp.zeros, - swap_dims(key.shape), key.dtype) - cached_value = self.variable('cache', 'cached_value', jnp.zeros, - swap_dims(value.shape), value.dtype) - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.int32)) - if is_initialized: - batch, num_heads, head_dim, length = (cached_key.value.shape) - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - # Sanity shape check of cached key against input query. - expected_shape = (batch, 1, num_heads, head_dim) - if expected_shape != query.shape: - raise ValueError('Autoregressive cache shape error, ' - 'expected query shape %s instead got %s.' % - (expected_shape, query.shape)) - - # Create a OHE of the current index. NOTE: the index is increased below. - cur_index = cache_index.value - one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) - # In order to update the key, value caches with the current key and - # value, we move the length axis to the back, similar to what we did for - # the cached ones above. - # Note these are currently the key and value of a single position, since - # we feed one position at a time. - one_token_key = jnp.moveaxis(key, -3, -1) - one_token_value = jnp.moveaxis(value, -3, -1) - # Update key, value caches with our new 1d spatial slices. - # We implement an efficient scatter into the cache via one-hot - # broadcast and addition. - key = cached_key.value + one_token_key * one_hot_indices - value = cached_value.value + one_token_value * one_hot_indices - cached_key.value = key - cached_value.value = value - cache_index.value = cache_index.value + 1 - # Move the keys and values back to their original shapes. - key = jnp.moveaxis(key, -1, -3) - value = jnp.moveaxis(value, -1, -3) - - # Causal mask for cached decoder self-attention: our single query - # position should only attend to those key positions that have already - # been generated and cached, not the remaining zero elements. - mask = combine_masks( - mask, - jnp.broadcast_to( - jnp.arange(length) <= cur_index, - # (1, 1, length) represent (head dim, query length, key length) - # query length is 1 because during decoding we deal with one - # index. - # The same mask is applied to all batch elements and heads. - (batch, 1, 1, length))) - - # Grab the correct relative attention bias during decoding. This is - # only required during single step decoding. - if bias is not None: - # The bias is a full attention matrix, but during decoding we only - # have to take a slice of it. - # This is equivalent to bias[..., cur_index:cur_index+1, :]. - bias = dynamic_vector_slice_in_dim( - jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) - - # Convert the boolean attention mask to an attention bias. - if mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - mask > 0, - jnp.full(mask.shape, 0.).astype(self.dtype), - jnp.full(mask.shape, -1e10).astype(self.dtype)) - else: - attention_bias = None - - # Add provided bias term (e.g. relative position embedding). - if bias is not None: - attention_bias = combine_biases(attention_bias, bias) - - dropout_rng = None - if not deterministic and self.dropout_rate > 0.: - dropout_rng = self.make_rng('dropout') - - # Apply attention. - x = dot_product_attention( - query, - key, - value, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - dtype=self.dtype, - float32_logits=self.float32_logits) - - # Back to the original inputs dimensions. - out = DenseGeneral( - features=inputs_q.shape[-1], # output dim is set to the input dim. - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('joined_kv', 'embed'), - dtype=self.dtype, - name='out')( - x) - return out - - -def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) - - -def _canonicalize_tuple(x): - if isinstance(x, Iterable): - return tuple(x) - else: - return (x,) - - -#------------------------------------------------------------------------------ -# DenseGeneral for attention layers. -#------------------------------------------------------------------------------ -class DenseGeneral(nn.Module): - """A linear transformation (without bias) with flexible axes. - - Attributes: - features: tuple with numbers of output features. - axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - """ - features: Union[Iterable[int], int] - axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'truncated_normal') - kernel_axes: Tuple[str, ...] = () - - @nn.compact - def __call__(self, inputs: Array) -> Array: - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - - Returns: - The transformed input. - """ - features = _canonicalize_tuple(self.features) - axis = _canonicalize_tuple(self.axis) - - inputs = jnp.asarray(inputs, self.dtype) - axis = _normalize_axes(axis, inputs.ndim) - - kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features - kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), - np.prod(features)) - kernel = param_with_axes( - 'kernel', - self.kernel_init, - kernel_param_shape, - jnp.float32, - axes=self.kernel_axes) - kernel = jnp.asarray(kernel, self.dtype) - kernel = jnp.reshape(kernel, kernel_shape) - - contract_ind = tuple(range(0, len(axis))) - return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) - - -def _convert_to_activation_function( - fn_or_string: Union[str, Callable]) -> Callable: - """Convert a string to an activation function.""" - if fn_or_string == 'linear': - return lambda x: x - elif isinstance(fn_or_string, str): - return getattr(nn, fn_or_string) - elif callable(fn_or_string): - return fn_or_string - else: - raise ValueError("don't know how to convert %s to an activation function" % - (fn_or_string,)) - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. - - Attributes: - intermediate_dim: Shared dimension of hidden layers. - activations: Type of activations for each layer. Each element is either - 'linear', a string function name in flax.linen, or a function. - kernel_init: Kernel function, passed to the dense layers. - deterministic: Whether the dropout layers should be deterministic. - intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. - """ - intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable]] = ('relu',) - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'truncated_normal') - intermediate_dropout_rate: float = 0.1 - dtype: Any = jnp.float32 - - @nn.compact - def __call__(self, inputs, decode: bool = False, deterministic: bool = False): - """Applies Transformer MlpBlock module.""" - # Iterate over specified MLP input activation functions. - # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. - activations = [] - for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' - x = DenseGeneral( - self.intermediate_dim, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), - name=dense_name)( - inputs) - x = _convert_to_activation_function(act_fn)(x) - activations.append(x) - - # Take elementwise product of above intermediate activations. - x = functools.reduce(operator.mul, activations) - # Apply dropout and final dense output projection. - x = nn.Dropout( - rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic) # Broadcast along length. - x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) - output = DenseGeneral( - inputs.shape[-1], - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - name='wo')( - x) - return output - - -class Embed(nn.Module): - """A parameterized function from integers [0, n) to d-dimensional vectors. - - Attributes: - num_embeddings: number of embeddings. - features: number of feature dimensions for each embedding. - dtype: the dtype of the embedding vectors (default: float32). - embedding_init: embedding initializer. - one_hot: performs the gather with a one-hot contraction rather than a true - gather. This is currently needed for SPMD partitioning. - """ - num_embeddings: int - features: int - cast_input_dtype: Optional[DType] = None - dtype: DType = jnp.float32 - attend_dtype: Optional[DType] = None - embedding_init: Initializer = default_embed_init - one_hot: bool = False - embedding: Array = dataclasses.field(init=False) - - def setup(self): - self.embedding = param_with_axes( - 'embedding', - self.embedding_init, (self.num_embeddings, self.features), - jnp.float32, - axes=('vocab', 'embed')) - - def __call__(self, inputs: Array) -> Array: - """Embeds the inputs along the last dimension. - - Args: - inputs: input data, all dimensions are considered batch dimensions. - - Returns: - Output which is embedded input data. The output shape follows the input, - with an additional `features` dimension appended. - """ - if self.cast_input_dtype: - inputs = inputs.astype(self.cast_input_dtype) - if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') - if self.one_hot: - iota = lax.iota(jnp.int32, self.num_embeddings) - one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) - output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) - else: - output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = with_sharding_constraint(output, ('batch', 'length', 'embed')) - return output - - def attend(self, query: Array) -> Array: - """Attend over the embedding using a query array. - - Args: - query: array with last dimension equal the feature depth `features` of the - embedding. - - Returns: - An array with final dim `num_embeddings` corresponding to the batched - inner-product of the array of query vectors against each embedding. - Commonly used for weight-sharing between embeddings and logit transform - in NLP models. - """ - dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype - return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) - - -class RelativePositionBiases(nn.Module): - """Adds T5-style relative positional embeddings to the attention logits. - - Attributes: - num_buckets: Number of buckets to bucket distances between key and query - positions into. - max_distance: Maximum distance before everything is lumped into the last - distance bucket. - num_heads: Number of heads in the attention layer. Each head will get a - different relative position weighting. - dtype: Type of arrays through this module. - embedding_init: initializer for relative embedding table. - """ - num_buckets: int - max_distance: int - num_heads: int - dtype: Any - embedding_init: Callable[..., Array] = nn.linear.default_embed_init - - @staticmethod - def _relative_position_bucket(relative_position, - bidirectional=True, - num_buckets=32, - max_distance=128): - """Translate relative position to a bucket number for relative attention. - - The relative position is defined as memory_position - query_position, i.e. - the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are - invalid. - We use smaller buckets for small absolute relative_position and larger - buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative - positions <=-max_distance map to the same bucket. This should allow for - more graceful generalization to longer sequences than the model has been - trained on. - - Args: - relative_position: an int32 array - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 - values in the range [0, num_buckets) - """ - ret = 0 - n = -relative_position - if bidirectional: - num_buckets //= 2 - ret += (n < 0).astype(np.int32) * num_buckets - n = np.abs(n) - else: - n = np.maximum(n, 0) - # now n is in the range [0, inf) - max_exact = num_buckets // 2 - is_small = (n < max_exact) - val_if_large = max_exact + ( - np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) / - np.log(max_distance / max_exact) * - (num_buckets - max_exact)).astype(np.int32) - val_if_large = np.minimum(val_if_large, num_buckets - 1) - ret += np.where(is_small, n, val_if_large) - return ret - - @nn.compact - def __call__(self, qlen, klen, bidirectional=True): - """Produce relative position embedding attention biases. - - Args: - qlen: attention query length. - klen: attention key length. - bidirectional: whether to allow positive memory-query relative position - embeddings. - - Returns: - output: `(1, len, q_len, k_len)` attention bias - """ - # TODO(levskaya): should we be computing this w. numpy as a program - # constant? - context_position = np.arange(qlen, dtype=jnp.int32)[:, None] - memory_position = np.arange(klen, dtype=jnp.int32)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - rp_bucket = self._relative_position_bucket( - relative_position, - bidirectional=bidirectional, - num_buckets=self.num_buckets, - max_distance=self.max_distance) - relative_attention_bias = param_with_axes( - 'rel_embedding', - self.embedding_init, (self.num_heads, self.num_buckets), - jnp.float32, - axes=('heads', 'relpos_buckets')) - - relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) - # Instead of using a slow gather, we create a leading-dimension one-hot - # array from rp_bucket and use it to perform the gather-equivalent via a - # contraction, i.e.: - # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). - # This is equivalent to relative_attention_bias[:, rp_bucket] - bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) - rp_bucket_one_hot = jnp.array( - rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) - # --> shape (qlen, klen, num_heads) - values = lax.dot_general( - relative_attention_bias, - rp_bucket_one_hot, - ( - ((1,), (0,)), # rhs, lhs contracting dims - ((), ()))) # no batched dims - # Add a singleton batch dimension. - # --> shape (1, num_heads, qlen, klen) - return values[jnp.newaxis, ...] - - -#------------------------------------------------------------------------------ -# T5 Layernorm - no subtraction of mean or bias. -#------------------------------------------------------------------------------ -class LayerNorm(nn.Module): - """T5 Layer normalization operating on the last axis of the input data.""" - epsilon: float = 1e-6 - dtype: Any = jnp.float32 - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - """Applies layer normalization on the input.""" - x = jnp.asarray(x, jnp.float32) - features = x.shape[-1] - mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) - scale = param_with_axes( - 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) - - scale = jnp.asarray(scale, self.dtype) - return y * scale - - -#------------------------------------------------------------------------------ -# Mask-making utility functions. -#------------------------------------------------------------------------------ -def make_attention_mask(query_input: Array, - key_input: Array, - pairwise_fn: Callable = jnp.multiply, - extra_batch_dims: int = 0, - dtype: DType = jnp.float32) -> Array: - """Mask-making helper for attention weights. - - In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the - attention weights will be `[batch, heads, len_q, len_kv]` and this - function will produce `[batch, 1, len_q, len_kv]`. - - Args: - query_input: a batched, flat input of query_length size - key_input: a batched, flat input of key_length size - pairwise_fn: broadcasting elementwise comparison function - extra_batch_dims: number of extra batch dims to add singleton axes for, none - by default - dtype: mask return dtype - - Returns: - A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. - """ - # [batch, len_q, len_kv] - mask = pairwise_fn( - # [batch, len_q] -> [batch, len_q, 1] - jnp.expand_dims(query_input, axis=-1), - # [batch, len_q] -> [batch, 1, len_kv] - jnp.expand_dims(key_input, axis=-2)) - - # [batch, 1, len_q, len_kv]. This creates the head dim. - mask = jnp.expand_dims(mask, axis=-3) - mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) - return mask.astype(dtype) - - -def make_causal_mask(x: Array, - extra_batch_dims: int = 0, - dtype: DType = jnp.float32) -> Array: - """Make a causal mask for self-attention. - - In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights - will be `[batch, heads, len, len]` and this function will produce a - causal mask of shape `[batch, 1, len, len]`. - - Note that a causal mask does not depend on the values of x; it only depends on - the shape. If x has padding elements, they will not be treated in a special - manner. - - Args: - x: input array of shape `[batch, len]` - extra_batch_dims: number of batch dims to add singleton axes for, none by - default - dtype: mask return dtype - - Returns: - A `[batch, 1, len, len]` shaped causal mask for 1d attention. - """ - idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) - return make_attention_mask( - idxs, - idxs, - jnp.greater_equal, - extra_batch_dims=extra_batch_dims, - dtype=dtype) - - -def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): - """Combine attention masks. - - Args: - *masks: set of attention mask arguments to combine, some can be None. - dtype: final mask dtype - - Returns: - Combined mask, reduced by logical and, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( - f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') - mask, *other_masks = masks - for other_mask in other_masks: - mask = jnp.logical_and(mask, other_mask) - return mask.astype(dtype) - - -def combine_biases(*masks: Optional[Array]): - """Combine attention biases. - - Args: - *masks: set of attention bias arguments to combine, some can be None. - - Returns: - Combined mask, reduced by summation, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( - f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') - mask, *other_masks = masks - for other_mask in other_masks: - mask = mask + other_mask - return mask - - -def make_decoder_mask(decoder_target_tokens: Array, - dtype: DType, - decoder_causal_attention: Optional[Array] = None, - decoder_segment_ids: Optional[Array] = None) -> Array: - """Compute the self-attention mask for a decoder. - - Decoder mask is formed by combining a causal mask, a padding mask and an - optional packing mask. If decoder_causal_attention is passed, it makes the - masking non-causal for positions that have value of 1. - - A prefix LM is applied to a dataset which has a notion of "inputs" and - "targets", e.g., a machine translation task. The inputs and targets are - concatenated to form a new target. `decoder_target_tokens` is the concatenated - decoder output tokens. - - The "inputs" portion of the concatenated sequence can attend to other "inputs" - tokens even for those at a later time steps. In order to control this - behavior, `decoder_causal_attention` is necessary. This is a binary mask with - a value of 1 indicating that the position belonged to "inputs" portion of the - original dataset. - - Example: - - Suppose we have a dataset with two examples. - - ds = [{"inputs": [6, 7], "targets": [8]}, - {"inputs": [3, 4], "targets": [5]}] - - After the data preprocessing with packing, the two examples are packed into - one example with the following three fields (some fields are skipped for - simplicity). - - decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] - decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] - decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] - - where each array has [batch, length] shape with batch size being 1. Then, - this function computes the following mask. - - mask = [[[[1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0]]]] - - mask[b, 1, :, :] represents the mask for the example `b` in the batch. - Because mask is for a self-attention layer, the mask's shape is a square of - shape [query length, key length]. - - mask[b, 1, i, j] = 1 means that the query token at position i can attend to - the key token at position j. - - Args: - decoder_target_tokens: decoder output tokens. [batch, length] - dtype: dtype of the output mask. - decoder_causal_attention: a binary mask indicating which position should - only attend to earlier positions in the sequence. Others will attend - bidirectionally. [batch, length] - decoder_segment_ids: decoder segmentation info for packed examples. [batch, - length] - - Returns: - the combined decoder mask. - """ - masks = [] - # The same mask is applied to all attention heads. So the head dimension is 1, - # i.e., the mask will be broadcast along the heads dim. - # [batch, 1, length, length] - causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) - - # Positions with value 1 in `decoder_causal_attneition` can attend - # bidirectionally. - if decoder_causal_attention is not None: - # [batch, 1, length, length] - inputs_mask = make_attention_mask( - decoder_causal_attention, - decoder_causal_attention, - jnp.logical_and, - dtype=dtype) - masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) - else: - masks.append(causal_mask) - - # Padding mask. - masks.append( - make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype)) - - # Packing mask - if decoder_segment_ids is not None: - masks.append( - make_attention_mask( - decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype)) - - return combine_masks(*masks, dtype=dtype) diff --git a/t5x-main/t5x/contrib/gpu/t5/layers_test.py b/t5x-main/t5x/contrib/gpu/t5/layers_test.py deleted file mode 100644 index d04a50963c74a9cd6cf361d179079e59130feda2..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/layers_test.py +++ /dev/null @@ -1,620 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for attention classes.""" - -import dataclasses -from typing import Optional -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from flax import linen as nn -from flax.core import freeze -from flax.linen import partitioning as nn_partitioning -import jax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp -import numpy as np -from t5x.contrib.gpu.t5 import layers - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -Array = jnp.ndarray -AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name - - -class SelfAttention(layers.MultiHeadDotProductAttention): - """Self-attention special case of multi-head dot-product attention.""" - - @nn.compact - def __call__(self, - inputs_q: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - deterministic: bool = False): - return super().__call__( - inputs_q, inputs_q, mask, bias, deterministic=deterministic) - - -@dataclasses.dataclass(frozen=True) -class SelfAttentionArgs: - num_heads: int = 1 - batch_size: int = 2 - # qkv_features: int = 3 - head_dim: int = 3 - # out_features: int = 4 - q_len: int = 5 - features: int = 6 - dropout_rate: float = 0.1 - deterministic: bool = False - decode: bool = False - float32_logits: bool = False - - def __post_init__(self): - # If we are doing decoding, the query length should be 1, because are doing - # autoregressive decoding where we feed one position at a time. - assert not self.decode or self.q_len == 1 - - def init_args(self): - return dict( - num_heads=self.num_heads, - head_dim=self.head_dim, - dropout_rate=self.dropout_rate, - float32_logits=self.float32_logits) - - def apply_args(self): - inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) - mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - return { - 'inputs_q': inputs_q, - 'mask': mask, - 'bias': bias, - 'deterministic': self.deterministic - } - - -class AttentionTest(parameterized.TestCase): - - def test_dot_product_attention_shape(self): - # This test only checks for shape but tries to make sure all code paths are - # reached. - dropout_rng = random.PRNGKey(0) - batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 - - query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) - key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) - value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) - bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) - - args = dict( - query=query, - key=key, - value=value, - bias=bias, - dropout_rng=dropout_rng, - dropout_rate=0.5, - deterministic=False, - ) - - output = layers.dot_product_attention(**args) - self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) - - def test_make_attention_mask_multiply_pairwise_fn(self): - decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) - attention_mask = layers.make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32) - expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) - expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) - self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_attention_mask_equal_pairwise_fn(self): - segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) - attention_mask = layers.make_attention_mask( - segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0], - [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]]) - expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]]) - self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_causal_mask_with_padding(self): - x = jnp.array([[7, 0, 0], [8, 5, 0]]) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (2, 1, 3, 3)) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]], - jnp.float32) - np.testing.assert_allclose(y[0], expected_y) - np.testing.assert_allclose(y[1], expected_y) - - def test_make_causal_mask_extra_batch_dims(self): - x = jnp.ones((3, 3, 5)) - y = layers.make_causal_mask(x, extra_batch_dims=2) - self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) - - def test_make_causal_mask(self): - x = jnp.ones((1, 3)) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (1, 1, 3, 3)) - expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]], - jnp.float32) - np.testing.assert_allclose(y, expected_y) - - def test_combine_masks(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), None, - jnp.array([1, 1, 1, 1], jnp.float32), - jnp.array([1, 1, 1, 0], jnp.float32) - ] - y = layers.combine_masks(*masks) - np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) - - def test_combine_biases(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), None, - jnp.array([0, 1, 1, 1], jnp.float32), - jnp.array([0, 1, 1, 0], jnp.float32) - ] - y = layers.combine_biases(*masks) - np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) - - def test_make_decoder_mask_lm_unpacked(self): - decoder_target_tokens = jnp.array([6, 7, 3, 0]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32) - expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], - [0, 0, 0, 0]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_lm_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_segment_ids=decoder_segment_ids) - expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention) - expected_mask = jnp.array( - [[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]], - dtype=jnp.float32) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_packed(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids) - expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0]]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention) - expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], - [0, 0, 0, 0]]) - expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], - [0, 0, 0, 0]]) - self.assertEqual(mask.shape, (2, 1, 4, 4)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - np.testing.assert_array_equal(mask[1, 0], expected_mask1) - - def test_make_decoder_mask_composite_causal_attention(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention) - expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0]]) - - self.assertEqual(mask.shape, (1, 1, 7, 7)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - def test_make_decoder_mask_composite_causal_attention_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids) - expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1]]) - - self.assertEqual(mask.shape, (1, 1, 9, 9)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - @parameterized.parameters({'f': 20}, {'f': 22}) - def test_multihead_dot_product_attention(self, f): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - np.random.seed(0) - inputs_q = np.random.randn(b, q, f) - inputs_kv = np.random.randn(b, k, f) - - # Projection: [b, q, f] -> [b, q, h, d] - # So the kernels have to be [f, h, d] - query_kernel = np.random.randn(f, h, d) - key_kernel = np.random.randn(f, h, d) - value_kernel = np.random.randn(f, h, d) - # `out` calculation: [b, q, h, d] -> [b, q, f] - # So kernel has to be [h, d, f] - out_kernel = np.random.randn(h, d, f) - - params = { - 'query': { - 'kernel': query_kernel.reshape(f, -1) - }, - 'key': { - 'kernel': key_kernel.reshape(f, -1) - }, - 'value': { - 'kernel': value_kernel.reshape(f, -1) - }, - 'out': { - 'kernel': out_kernel.reshape(-1, f) - } - } - y = layers.MultiHeadDotProductAttention(**args).apply( - {'params': freeze(params)}, inputs_q, inputs_kv) - - query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) - key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) - value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = nn.softmax(logits, axis=-1) - combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) - y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) - np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) - - def test_multihead_dot_product_attention_caching(self): - # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim - b, h, d, k = 2, 3, 4, 5 - f = h * d - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - cache = { - 'cached_key': np.zeros((b, h, d, k)), - 'cached_value': np.zeros((b, h, d, k)), - 'cache_index': np.array(0) - } - inputs_q = np.random.randn(b, 1, f) - inputs_kv = np.random.randn(b, 1, f) - - # Mock dense general such that q, k, v projections are replaced by simple - # reshaping. - def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument - return x.reshape(b, -1, h, d) - - with mock.patch.object( - layers.DenseGeneral, '__call__', new=mock_dense_general): - _, mutated = layers.MultiHeadDotProductAttention(**args).apply( - {'cache': freeze(cache)}, - inputs_q, - inputs_kv, - decode=True, - mutable=['cache']) - updated_cache = mutated['cache'] - - # Perform the same mocked projection to generate the expected cache. - # (key|value): [b, 1, h, d] - key = mock_dense_general(None, inputs_kv) - value = mock_dense_general(None, inputs_kv) - - # cached_(key|value): [b, h, d, k] - cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] - cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] - cache['cache_index'] = np.array(1) - for name, array in cache.items(): - np.testing.assert_allclose(array, updated_cache[name]) - - def test_dot_product_attention(self): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - np.random.seed(0) - query = np.random.randn(b, q, h, d) - key = np.random.randn(b, k, h, d) - value = np.random.randn(b, k, h, d) - bias = np.random.randn(b, h, q, k) - attn_out = layers.dot_product_attention(query, key, value, bias=bias) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = jax.nn.softmax(logits + bias, axis=-1) - expected = np.einsum('bhqk,bkhd->bqhd', weights, value) - np.testing.assert_allclose(attn_out, expected, atol=1e-6) - - -class EmbeddingTest(parameterized.TestCase): - - def test_embedder_raises_exception_for_incorrect_input_type(self): - """Tests that inputs are integers and that an exception is raised if not.""" - embed = layers.Embed(num_embeddings=10, features=5) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - bad_inputs = inputs.astype(np.float32) - with self.assertRaisesRegex( - ValueError, 'Input type must be an integer or unsigned integer.'): - _ = embed.apply(variables, bad_inputs) - - @parameterized.named_parameters( - { - 'testcase_name': 'with_ones', - 'init_fn': jax.nn.initializers.ones, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 5 * 10, - }, { - 'testcase_name': 'with_zeros', - 'init_fn': jax.nn.initializers.zeros, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 0, - }) - def test_embedding_initializes_correctly(self, init_fn, num_embeddings, - features, matrix_sum): - """Tests if the Embed class initializes with the requested initializer.""" - embed = layers.Embed( - num_embeddings=num_embeddings, - features=features, - embedding_init=init_fn) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) - - def test_embedding_matrix_shape(self): - """Tests that the embedding matrix has the right shape.""" - num_embeddings = 10 - features = 5 - embed = layers.Embed(num_embeddings=num_embeddings, features=features) - inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual((num_embeddings, features), embedding_matrix.shape) - - def test_embedding_attend(self): - """Tests that attending with ones returns sum of embedding vectors.""" - features = 5 - embed = layers.Embed(num_embeddings=10, features=features) - inputs = np.array([[1]], dtype=np.int64) - variables = embed.init(jax.random.PRNGKey(0), inputs) - query = np.ones(features, dtype=np.float32) - result = embed.apply(variables, query, method=embed.attend) - expected = np.sum(variables['params']['embedding'], -1) - np.testing.assert_array_almost_equal(result, expected) - - -class DenseTest(parameterized.TestCase): - - def test_dense_general_no_bias(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=4, - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - self.assertEqual(y.shape, (1, 4)) - np.testing.assert_allclose(y, np.full((1, 4), 3.)) - - def test_dense_general_two_features(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=(2, 2), - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - # We transform the last input dimension to two output dimensions (2, 2). - np.testing.assert_allclose(y, np.full((1, 2, 2), 3.)) - - def test_dense_general_two_axes(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 2, 2)) - model = layers.DenseGeneral( - features=3, - axis=(-2, 2), # Note: this is the same as (1, 2). - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - # We transform the last two input dimensions (2, 2) to one output dimension. - np.testing.assert_allclose(y, np.full((1, 3), 4.)) - - def test_mlp_same_out_dim(self): - module = layers.MlpBlock( - intermediate_dim=4, - activations=('relu',), - kernel_init=nn.initializers.xavier_uniform(), - dtype=jnp.float32, - ) - inputs = np.array( - [ - # Batch 1. - [[1, 1], [1, 1], [1, 2]], - # Batch 2. - [[2, 2], [3, 1], [2, 2]], - ], - dtype=np.float32) - params = module.init(random.PRNGKey(0), inputs, deterministic=True) - self.assertEqual( - jax.tree_map(lambda a: a.tolist(), params), { - 'params': { - 'wi': { - 'kernel': [[ - -0.8675811290740967, 0.08417510986328125, - 0.022586345672607422, -0.9124102592468262 - ], - [ - -0.19464373588562012, 0.49809837341308594, - 0.7808468341827393, 0.9267289638519287 - ]], - }, - 'wo': { - 'kernel': [[0.01154780387878418, 0.1397249698638916], - [0.974980354309082, 0.5903260707855225], - [-0.05997943878173828, 0.616570234298706], - [0.2934272289276123, 0.8181164264678955]], - }, - }, - 'params_axes': { - 'wi': { - 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), - }, - 'wo': { - 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), - }, - }, - }) - result = module.apply(params, inputs, deterministic=True) - np.testing.assert_allclose( - result.tolist(), - [[[0.5237172245979309, 0.8508185744285583], - [0.5237172245979309, 0.8508185744285583], - [1.2344461679458618, 2.3844780921936035]], - [[1.0474344491958618, 1.7016371488571167], - [0.6809444427490234, 0.9663378596305847], - [1.0474344491958618, 1.7016371488571167]]], - rtol=1e-6, - ) - - -class RelativePositionBiasesTest(absltest.TestCase): - - def setUp(self): - self.num_heads = 3 - self.query_len = 5 - self.key_len = 7 - self.relative_attention = layers.RelativePositionBiases( - num_buckets=12, - max_distance=10, - num_heads=3, - dtype=jnp.float32, - ) - super(RelativePositionBiasesTest, self).setUp() - - def test_relative_attention_bidirectional_params(self): - """Tests that bidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) - param_shapes = jax.tree_map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': - AxisMetadata(names=('heads', 'relpos_buckets')), - } - }) - - def test_regression_relative_attention_bidirectional_values(self): - """Tests that bidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) - self.assertEqual(outputs.shape, - (1, self.num_heads, self.query_len, self.key_len)) - self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - self.assertAlmostEqual(outputs[0, 1, 4, 6], 0.14510104, places=5) - self.assertAlmostEqual(outputs[0, 2, 4, 6], -0.36783996, places=5) - - def test_relative_attention_unidirectional_params(self): - """Tests that unidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) - param_shapes = jax.tree_map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': - AxisMetadata(names=('heads', 'relpos_buckets')), - } - }) - - def test_regression_relative_attention_unidirectional_values(self): - """Tests that unidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) - self.assertEqual(outputs.shape, - (1, self.num_heads, self.query_len, self.key_len)) - self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - self.assertAlmostEqual(outputs[0, 1, 4, 6], -0.13101986, places=5) - self.assertAlmostEqual(outputs[0, 2, 4, 6], 0.39296466, places=5) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/contrib/gpu/t5/local_tiny.gin b/t5x-main/t5x/contrib/gpu/t5/local_tiny.gin deleted file mode 100644 index bc3b80bcb4d3ddb1027ab2f20f977f182530a94b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/local_tiny.gin +++ /dev/null @@ -1,68 +0,0 @@ -# A gin file to make the Transformer models tiny for faster local testing. -# -# When testing locally with CPU, there are a few things that we need. -# - tiny model size -# - small enough batch size -# - small sequence length -# - determinstic dataset pipeline -# -# This gin file adds such configs. To use this gin file, add it on top of the -# existing full-scale gin files. The ordering of the gin file matters. So this -# should be added after all the other files are added to override the same -# configurables. - -from __gin__ import dynamic_registration - -from t5x import partitioning -from t5x import trainer -from t5x import utils -from t5x.contrib.gpu.t5 import network - -import __main__ as train_script - -train_script.train.random_seed = 42 # dropout seed -train/utils.DatasetConfig.seed = 42 # dataset seed - -TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 8} -LABEL_SMOOTHING = 0.0 -TRAIN_STEPS = 3 - -# Network specification overrides -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 - mlp_activations = ('gelu', 'linear') - dropout_rate = 0.0 - logits_via_embedding = False - -train/utils.DatasetConfig: - batch_size = 8 - shuffle = False - -train_eval/utils.DatasetConfig.batch_size = 8 - -train_script.train: - eval_period = 3 - eval_steps = 3 - -trainer.Trainer.num_microbatches = 0 -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - -utils.CheckpointConfig: - restore = None - save = None - -infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS - - -# DISABLE INFERENCE EVAL -# train_script.train.infer_eval_dataset_cfg = None diff --git a/t5x-main/t5x/contrib/gpu/t5/mt5/__init__.py b/t5x-main/t5x/contrib/gpu/t5/mt5/__init__.py deleted file mode 100644 index da022c16301721a096a208e8bdb2a71bb87f9788..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/mt5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/contrib/gpu/t5/mt5/base.gin b/t5x-main/t5x/contrib/gpu/t5/mt5/base.gin deleted file mode 100644 index 73e56b37bc336b48305ed17b20fae35f1e0c8e65..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/mt5/base.gin +++ /dev/null @@ -1,55 +0,0 @@ -# mT5 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.contrib.gpu.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 250112 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 2048 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False diff --git a/t5x-main/t5x/contrib/gpu/t5/mt5/large.gin b/t5x-main/t5x/contrib/gpu/t5/mt5/large.gin deleted file mode 100644 index 5d0c085907c54c04c2665a6f4f8a500908e18e10..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/mt5/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 Large model. - -include 't5x/contrib/gpu/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 2816 diff --git a/t5x-main/t5x/contrib/gpu/t5/mt5/small.gin b/t5x-main/t5x/contrib/gpu/t5/mt5/small.gin deleted file mode 100644 index 9f4998c2f66ae089faa590606ede21e0fdcc7937..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/mt5/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 Small model. - -include 't5x/contrib/gpu/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 6 - num_encoder_layers = 8 - num_decoder_layers = 8 - head_dim = 64 - mlp_dim = 1024 diff --git a/t5x-main/t5x/contrib/gpu/t5/mt5/tiny.gin b/t5x-main/t5x/contrib/gpu/t5/mt5/tiny.gin deleted file mode 100644 index 04ad2c1ec22c0da73cde09cf9ea691e3bdf2fd63..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/mt5/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/contrib/gpu/t5/mt5/xl.gin b/t5x-main/t5x/contrib/gpu/t5/mt5/xl.gin deleted file mode 100644 index 18554580e0346e9185943a94aff4709fbd665a4b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/mt5/xl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 XL model. - -include 't5x/contrib/gpu/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2048 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 5120 diff --git a/t5x-main/t5x/contrib/gpu/t5/mt5/xxl.gin b/t5x-main/t5x/contrib/gpu/t5/mt5/xxl.gin deleted file mode 100644 index b27aeada1e1bc3bd15065f69d5a4d41492b718dd..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/mt5/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 XXL model. - -include 't5x/contrib/gpu/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4096 - num_heads = 64 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 10240 diff --git a/t5x-main/t5x/contrib/gpu/t5/network.py b/t5x-main/t5x/contrib/gpu/t5/network.py deleted file mode 100644 index dd6107869cfd92bb8d999f18c2376fdbca570fef..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/network.py +++ /dev/null @@ -1,429 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5.1.1 Transformer model.""" - -from typing import Any, Sequence - -from flax import linen as nn -from flax import struct -import jax.numpy as jnp -from t5x.contrib.gpu.t5 import layers - - -@struct.dataclass -class T5Config: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - vocab_size: int - # Activation dtypes. - dtype: Any = jnp.float32 - emb_dim: int = 512 - num_heads: int = 8 - num_encoder_layers: int = 6 - num_decoder_layers: int = 6 - head_dim: int = 64 - mlp_dim: int = 2048 - # Activation functions are retrieved from Flax. - mlp_activations: Sequence[str] = ('relu',) - dropout_rate: float = 0.1 - # If `True`, the embedding weights are used in the decoder output layer. - logits_via_embedding: bool = False - # Whether to accumulate attention logits in float32 regardless of dtype. - float32_attention_logits: bool = False - # Whether to scale attention logits by sqrt(d_k). Default to False for adafactor - scale_attn_logits: bool = False - - -class EncoderLayer(nn.Module): - """Transformer encoder layer.""" - config: T5Config - relative_embedding: nn.Module - - @nn.compact - def __call__(self, inputs, encoder_mask=None, deterministic=False): - cfg = self.config - - # Relative position embedding as attention biases. - encoder_bias = self.relative_embedding(inputs.shape[-2], inputs.shape[-2], - True) - - # Attention block. - assert inputs.ndim == 3 - x = layers.LayerNorm( - dtype=cfg.dtype, name='pre_attention_layer_norm')( - inputs) - # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - float32_logits=cfg.float32_attention_logits, - name='attention', - scale_attn_logits=cfg.scale_attn_logits)( - x, x, encoder_mask, encoder_bias, deterministic=deterministic) - x = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic) - x = x + inputs - - # MLP block. - y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) - # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = layers.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - name='mlp', - )(y, deterministic=deterministic) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic) - y = y + x - - return y - - -class DecoderLayer(nn.Module): - """Transformer decoder layer that attends to the encoder.""" - config: T5Config - relative_embedding: nn.Module - - @nn.compact - def __call__(self, - inputs, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - deterministic=False, - decode=False, - max_decode_length=None): - cfg = self.config - - # Relative position embedding as attention biases. - l = max_decode_length if decode and max_decode_length else inputs.shape[-2] - decoder_bias = self.relative_embedding(l, l, False) - - # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - x = layers.LayerNorm( - dtype=cfg.dtype, name='pre_self_attention_layer_norm')( - inputs) - - # Self-attention block - x = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - float32_logits=cfg.float32_attention_logits, - name='self_attention', - scale_attn_logits=cfg.scale_attn_logits)( - x, - x, - decoder_mask, - decoder_bias, - deterministic=deterministic, - decode=decode) - x = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic) - x = x + inputs - - # Encoder-Decoder block. - y = layers.LayerNorm( - dtype=cfg.dtype, name='pre_cross_attention_layer_norm')( - x) - y = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - float32_logits=cfg.float32_attention_logits, - name='encoder_decoder_attention', - scale_attn_logits=cfg.scale_attn_logits)( - y, encoded, encoder_decoder_mask, deterministic=deterministic) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic) - y = y + x - - # MLP block. - z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y) - z = layers.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - name='mlp', - )(z, deterministic=deterministic) - z = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - z, deterministic=deterministic) - z = z + y - - return z - - -class Encoder(nn.Module): - """A stack of encoder layers.""" - config: T5Config - shared_embedding: nn.Module - - @nn.compact - def __call__(self, - encoder_input_tokens, - encoder_mask=None, - deterministic=False): - cfg = self.config - assert encoder_input_tokens.ndim == 2 # [batch, length] - rel_emb = layers.RelativePositionBiases( - num_buckets=32, - max_distance=128, - num_heads=cfg.num_heads, - dtype=cfg.dtype, - embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', - 'uniform'), - name='relpos_bias') - - # [batch, length] -> [batch, length, emb_dim] - x = self.shared_embedding(encoder_input_tokens.astype('int32')) - x = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic) - x = x.astype(cfg.dtype) - - for lyr in range(cfg.num_encoder_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = EncoderLayer( - config=cfg, relative_embedding=rel_emb, - name=f'layers_{lyr}')(x, encoder_mask, deterministic) - - x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) - return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) - - -class Decoder(nn.Module): - """A stack of decoder layers as a part of an encoder-decoder architecture.""" - config: T5Config - shared_embedding: nn.Module - - @nn.compact - def __call__(self, - encoded, - decoder_input_tokens, - decoder_positions=None, - decoder_mask=None, - encoder_decoder_mask=None, - deterministic=False, - decode=False, - max_decode_length=None): - cfg = self.config - assert decoder_input_tokens.ndim == 2 # [batch, len] - rel_emb = layers.RelativePositionBiases( - num_buckets=32, - max_distance=128, - num_heads=cfg.num_heads, - dtype=cfg.dtype, - embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', - 'uniform'), - name='relpos_bias') - - # [batch, length] -> [batch, length, emb_dim] - y = self.shared_embedding(decoder_input_tokens.astype('int32')) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic) - y = y.astype(cfg.dtype) - - for lyr in range(cfg.num_decoder_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = DecoderLayer( - config=cfg, relative_embedding=rel_emb, name=f'layers_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - deterministic=deterministic, - decode=decode, - max_decode_length=max_decode_length) - - y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic) - - # [batch, length, emb_dim] -> [batch, length, vocab_size] - if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for logit transform. - logits = self.shared_embedding.attend(y) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - else: - logits = layers.DenseGeneral( - cfg.vocab_size, - dtype=jnp.float32, # Use float32 for stabiliity. - kernel_axes=('embed', 'vocab'), - name='logits_dense')( - y) - return logits - - -class Transformer(nn.Module): - """An encoder-decoder Transformer model.""" - config: T5Config - - def setup(self): - cfg = self.config - self.shared_embedding = layers.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - attend_dtype=jnp.float32, # for logit training stability - embedding_init=nn.initializers.normal(stddev=1.0), - one_hot=False, - name='token_embedder') - - self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) - self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - - def encode(self, - encoder_input_tokens, - encoder_segment_ids=None, - enable_dropout=True): - """Applies Transformer encoder-branch on the inputs.""" - cfg = self.config - assert encoder_input_tokens.ndim == 2 # (batch, len) - - # Make padding attention mask. - encoder_mask = layers.make_attention_mask( - encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype) - # Add segmentation block-diagonal attention mask if using segmented data. - if encoder_segment_ids is not None: - encoder_mask = layers.combine_masks( - encoder_mask, - layers.make_attention_mask( - encoder_segment_ids, - encoder_segment_ids, - jnp.equal, - dtype=cfg.dtype)) - - return self.encoder( - encoder_input_tokens, encoder_mask, deterministic=not enable_dropout) - - def decode( - self, - encoded, - encoder_input_tokens, # only needed for masks - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - decoder_positions=None, - enable_dropout=True, - decode=False, - max_decode_length=None): - """Applies Transformer decoder-branch on encoded-input and target.""" - cfg = self.config - - # Make padding attention masks. - if decode: - # Do not mask decoder attention based on targets padding at - # decoding/inference time. - decoder_mask = None - encoder_decoder_mask = layers.make_attention_mask( - jnp.ones_like(decoder_target_tokens), - encoder_input_tokens > 0, - dtype=cfg.dtype) - else: - decoder_mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=cfg.dtype, - decoder_segment_ids=decoder_segment_ids) - encoder_decoder_mask = layers.make_attention_mask( - decoder_target_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype) - - # Add segmentation block-diagonal attention masks if using segmented data. - if encoder_segment_ids is not None: - if decode: - raise ValueError( - 'During decoding, packing should not be used but ' - '`encoder_segment_ids` was passed to `Transformer.decode`.') - - encoder_decoder_mask = layers.combine_masks( - encoder_decoder_mask, - layers.make_attention_mask( - decoder_segment_ids, - encoder_segment_ids, - jnp.equal, - dtype=cfg.dtype)) - - logits = self.decoder( - encoded, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - deterministic=not enable_dropout, - decode=decode, - max_decode_length=max_decode_length) - return logits - - def __call__(self, - encoder_input_tokens, - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - encoder_positions=None, - decoder_positions=None, - *, - enable_dropout: bool = True, - decode: bool = False): - """Applies Transformer model on the inputs. - - This method requires both decoder_target_tokens and decoder_input_tokens, - which is a shifted version of the former. For a packed dataset, it usually - has additional processing applied. For example, the first element of each - sequence has id 0 instead of the shifted EOS id from the previous sequence. - - Args: - encoder_input_tokens: input data to the encoder. - decoder_input_tokens: input token to the decoder. - decoder_target_tokens: target token to the decoder. - encoder_segment_ids: encoder segmentation info for packed examples. - decoder_segment_ids: decoder segmentation info for packed examples. - encoder_positions: encoder subsequence positions for packed examples. - decoder_positions: decoder subsequence positions for packed examples. - enable_dropout: Ensables dropout if set to True. - decode: Whether to prepare and use an autoregressive cache. - - Returns: - logits array from full transformer. - """ - encoded = self.encode( - encoder_input_tokens, - encoder_segment_ids=encoder_segment_ids, - enable_dropout=enable_dropout) - - return self.decode( - encoded, - encoder_input_tokens, # only used for masks - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=encoder_segment_ids, - decoder_segment_ids=decoder_segment_ids, - decoder_positions=decoder_positions, - enable_dropout=enable_dropout, - decode=decode) diff --git a/t5x-main/t5x/contrib/gpu/t5/network_test.py b/t5x-main/t5x/contrib/gpu/t5/network_test.py deleted file mode 100644 index 075193528580f72eab02c3ed4363d8da1e9d3f56..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/network_test.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for network.""" - -import os - -from absl import flags -from absl.testing import absltest -from absl.testing import parameterized -import jax -import numpy as np -import seqio -from t5x import adafactor -from t5x import models -from t5x import test_utils -from t5x.contrib.gpu.t5 import network - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -FLAGS = flags.FLAGS - - -def get_test_model(emb_dim, - head_dim, - num_heads, - mlp_dim, - dtype='float32', - vocab_size=32128, - num_encoder_layers=2, - num_decoder_layers=2): - config = network.T5Config( - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - vocab_size=vocab_size, - dropout_rate=0, - emb_dim=emb_dim, - num_heads=num_heads, - head_dim=head_dim, - mlp_dim=mlp_dim, - dtype=dtype, - mlp_activations=('gelu', 'linear')) - module = network.Transformer(config=config) - vocab = seqio.test_utils.sentencepiece_vocab() - optimizer_def = adafactor.Adafactor() - return models.EncoderDecoderModel( - module, vocab, vocab, optimizer_def=optimizer_def) - - -class NetworkTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - batch_size, max_decode_len, input_len = 2, 3, 4 - self.input_shapes = { - 'encoder_input_tokens': (batch_size, input_len), - 'decoder_input_tokens': (batch_size, max_decode_len) - } - np.random.seed(42) - self.batch = { - 'encoder_input_tokens': - np.random.randint(3, 10, size=(batch_size, input_len)), - 'decoder_input_tokens': - np.random.randint(3, 10, size=(batch_size, max_decode_len)), - 'decoder_target_tokens': - np.random.randint(3, 10, size=(batch_size, max_decode_len)) - } - - def test_t5_1_1_regression(self): - np.random.seed(0) - batch_size, max_decode_len, input_len = 2, 3, 4 - batch = { - 'encoder_input_tokens': - np.random.randint(3, 10, size=(batch_size, input_len)), - 'decoder_input_tokens': - np.random.randint(3, 10, size=(batch_size, max_decode_len)), - 'decoder_target_tokens': - np.random.randint(3, 10, size=(batch_size, max_decode_len)) - } - model = get_test_model( - emb_dim=13, - head_dim=64, - num_heads=8, - mlp_dim=2048, - vocab_size=10, - num_encoder_layers=3) - params = model.get_initial_variables( - jax.random.PRNGKey(42), self.input_shapes)['params'] - loss, _ = jax.jit(model.loss_fn)(params, batch, jax.random.PRNGKey(1)) - self.assertAlmostEqual(loss, 18.088945, delta=0.05) - - predicted, scores = model.predict_batch_with_aux(params, batch) - np.testing.assert_array_equal(predicted, [[7, 1, 0], [1, 0, 0]]) - np.testing.assert_allclose( - scores['scores'], [-3.0401115, -1.9265753], rtol=1e-3) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/11B.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_0/11B.gin deleted file mode 100644 index 086d67598badacfa060ca7e50e89159c45525f75..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/11B.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.0 11B model. - -include 't5x/contrib/gpu/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 128 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 128 - mlp_dim = 65536 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/3B.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_0/3B.gin deleted file mode 100644 index bb825b3ae83a0a7f501a4a9542d9f28762a9029b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/3B.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.0 3B model. - -include 't5x/contrib/gpu/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 128 - mlp_dim = 16384 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/__init__.py b/t5x-main/t5x/contrib/gpu/t5/t5_1_0/__init__.py deleted file mode 100644 index da022c16301721a096a208e8bdb2a71bb87f9788..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/base.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_0/base.gin deleted file mode 100644 index e53fdaaf3555ba6d780620a6d1dcf6b9e212449a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/base.gin +++ /dev/null @@ -1,55 +0,0 @@ -# T5.1.0 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.contrib.gpu.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 3072 - mlp_activations = ('relu',) - dropout_rate = %DROPOUT_RATE - logits_via_embedding = True diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/large.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_0/large.gin deleted file mode 100644 index d022b530dc8a336734566f932a2b277396c3d68c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.0 Large model. - -include 't5x/contrib/gpu/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 4096 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/small.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_0/small.gin deleted file mode 100644 index 6eed7e8372d27c4319b9fcfa59e8ac1f3fa3b754..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Small model. - -include 't5x/contrib/gpu/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 8 - num_encoder_layers = 6 - num_decoder_layers = 6 - head_dim = 64 - mlp_dim = 2048 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/tiny.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_0/tiny.gin deleted file mode 100644 index 04ad2c1ec22c0da73cde09cf9ea691e3bdf2fd63..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_0/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/__init__.py b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/__init__.py deleted file mode 100644 index da022c16301721a096a208e8bdb2a71bb87f9788..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin deleted file mode 100644 index b9973b8d66bfe0bf1158626b71093c66d08d3b0f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin +++ /dev/null @@ -1,19 +0,0 @@ -from __gin__ import dynamic_registration - -import optax -from t5x import optimizers -from t5x.contrib.gpu.t5 import network - -OPTIMIZER = @optimizers.chain() -optimizers.chain: - transformations = [@optax.clip_by_global_norm(), @optax.adamw()] - -optax.clip_by_global_norm: - max_norm = 1.0 - -optax.adamw: - learning_rate = 0.0001 - weight_decay = 0.01 - -network.T5Config: - scale_attn_logits = True diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/base.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/base.gin deleted file mode 100644 index 1ba24e302b997150742cb1c450f3b53dcd9d5416..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/base.gin +++ /dev/null @@ -1,55 +0,0 @@ -# T5.1.1 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.contrib.gpu.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 2048 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/__init__.py b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/__init__.py deleted file mode 100644 index da022c16301721a096a208e8bdb2a71bb87f9788..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_c4_pretrain.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_c4_pretrain.gin deleted file mode 100644 index a6c53a1d8983b2da158379f00343d97f7eea78ee..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_c4_pretrain.gin +++ /dev/null @@ -1,19 +0,0 @@ -# Register necessary SeqIO Tasks/Mixtures. -from __gin__ import dynamic_registration -import t5.data.mixtures -import __main__ as train_script - - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' -include 't5x/contrib/gpu/t5/configs/runs/pretrain.gin' - - -MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} -TRAIN_STEPS = 100000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 256 - - -train_script.train: - eval_period = 2000 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_eval.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_eval.gin deleted file mode 100644 index fe41661d865ae9e0675c546d29fa12973cd8e1f9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_eval.gin +++ /dev/null @@ -1,46 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as eval_script -import seqio -from t5.data import mixtures -from t5x import partitioning -from t5x import utils -from t5x import models - -include "t5x/contrib/gpu/t5/t5_1_1/base.gin" # defines %MODEL. - -INITIAL_CHECKPOINT_PATH = %gin.REQUIRED -EVAL_OUTPUT_DIR = %gin.REQUIRED - -DROPOUT_RATE = 0.0 # unused boilerplate - - -eval_script.evaluate: - model = %MODEL # imported from separate gin file - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - output_dir = %EVAL_OUTPUT_DIR - inference_evaluator_cls = @seqio.Evaluator - - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the dataset. - use_memory_cache = True - - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # Auto-computes the max feature lengths. - split = 'test' - batch_size = 32 - shuffle = False - seed = 42 - -partitioning.PjitPartitioner.num_partitions = 1 -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -utils.RestoreCheckpointConfig: - path = %INITIAL_CHECKPOINT_PATH - mode = 'specific' diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin deleted file mode 100644 index dc71d15f0a1f0094bb1479582337e432e8417e2b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin +++ /dev/null @@ -1,51 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -import t5.data.mixtures -from t5x import utils -from t5x import models - - -include 't5x/contrib/gpu/t5/configs/runs/finetune.gin' -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' - -BATCH_SIZE = 128 -MIXTURE_OR_TASK_NAME = "wmt14_enfr_v003" -TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256} -DROPOUT_RATE = 0.1 -TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. - -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000" - -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -LOSS_NORMALIZING_FACTOR = 233472 - -train_script.train: - eval_period = 100 - -train_script.train: - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = 64 - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = False - module = %MIXTURE_OR_TASK_MODULE - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the dataset. - use_memory_cache = True diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_train.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_train.gin deleted file mode 100644 index e805fc554c8ee1e07a4fa1e5c49328d6186d01a0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt14enfr_train.gin +++ /dev/null @@ -1,18 +0,0 @@ -from __gin__ import dynamic_registration -import t5.data.mixtures -import __main__ as train_script -from t5x import utils - -include 't5x/contrib/gpu/t5/configs/runs/pretrain.gin' -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' - -TRAIN_STEPS = 100000 -BATCH_SIZE = 128 -MIXTURE_OR_TASK_NAME = "wmt14_enfr_v003" -TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256} -DROPOUT_RATE = 0.1 - -train_script.train: - eval_period = 2000 -utils.SaveCheckpointConfig: - period = 200 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt19_ende_train.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt19_ende_train.gin deleted file mode 100644 index d2697b8f62b129c848f2f9ce11c2319954e40e6f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt19_ende_train.gin +++ /dev/null @@ -1,62 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import adafactor -from t5x import models -from t5x import partitioning -from t5x import trainer -from t5x import utils -from t5x.contrib.gpu.t5 import network - -include "t5x/contrib/gpu/t5/t5_1_1/base.gin" -include "t5x/contrib/gpu/t5/configs/runs/finetune.gin" - -MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" -MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} -TRAIN_STEPS = 5000 -LABEL_SMOOTHING = 0.1 -INITIAL_CHECKPOINT_PATH = None -# Note that `DROPOUT_RATE = 0.1` is specified in the finetune.gin but we just -# repeat to make it explicit. -DROPOUT_RATE = 0.1 - -train/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - use_custom_packing_ops = False - seed = 0 - -train_eval/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = False - use_custom_packing_ops = False - seed = 0 - -infer_eval/utils.DatasetConfig: - use_cached = False - -train_script.train: - eval_period = 250 - eval_steps = 20 - random_seed = 0 - use_hardware_rng = True - -utils.CheckpointConfig.restore = None -utils.SaveCheckpointConfig: - period = 500 # checkpoint frequency - keep = 1 - -# Decoder overrides -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -trainer.Trainer.num_microbatches = 2 -utils.create_learning_rate_scheduler.warmup_steps = 1000 - -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 1, 2) - -adafactor.Adafactor: - logical_factor_rules = @adafactor.standard_logical_factor_rules() diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_eval.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_eval.gin deleted file mode 100644 index a2bcef7ab4d6c3c9f4c3ef27762522e1f959c287..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_eval.gin +++ /dev/null @@ -1,34 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as eval_script -from t5.data import mixtures -from t5x import partitioning -from t5x import utils - -include "t5x/contrib/gpu/t5/t5_1_1/base.gin" # defines %MODEL. - -CHECKPOINT_PATH = %gin.REQUIRED # passed via commandline -EVAL_OUTPUT_DIR = %gin.REQUIRED # passed via commandline - -DROPOUT_RATE = 0.0 # unused boilerplate -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" - -eval_script.evaluate: - model = %MODEL # imported from separate gin file - dataset_cfg = @utils.DatasetConfig() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - output_dir = %EVAL_OUTPUT_DIR - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # Auto-computes the max feature lengths. - split = 'test' - batch_size = 32 - shuffle = False - seed = 42 - -partitioning.PjitPartitioner.num_partitions = 2 - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_from_scratch.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_from_scratch.gin deleted file mode 100644 index f4798c66c732d83fb55901a5ddcccabc83c373d6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_from_scratch.gin +++ /dev/null @@ -1,63 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils - -include "t5x/contrib/gpu/t5/t5_1_1/base.gin" -include "t5x/contrib/gpu/t5/configs/runs/pretrain.gin" - -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} -TRAIN_STEPS = 50000 -DROPOUT_RATE = 0.0 - -train/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - seed = 0 - -train_eval/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - seed = 0 - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # compute max - split = "validation" - seed = 0 - batch_size = 128 - shuffle = False - use_cached = False - -train_script.train: - eval_period = 500 - eval_steps = 20 - random_seed = 0 - use_hardware_rng = True - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - inference_evaluator_cls = @seqio.Evaluator - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the infer_eval dataset. - use_memory_cache = True - -utils.SaveCheckpointConfig: - period = 5000 # checkpoint frequency - -# `num_decodes` is equivalent to a beam size in a beam search decoding. -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -partitioning.PjitPartitioner.num_partitions = 2 - -utils.create_learning_rate_scheduler: - factors = 'constant * rsqrt_decay' - base_learning_rate = 1.0 - warmup_steps = 10000 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin deleted file mode 100644 index b735aaae119cbdff48018db38f8104d50e2fc56d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin +++ /dev/null @@ -1,51 +0,0 @@ -# This gin file is to show how to switch to an optimizer other than -# Adafactor. Gin configuration makes it easy by simply importing any available -# optimizer in t5x/optimizers module. Note the optimizers in t5x/optimizers are -# wrapped version of optimizers implemented in optax. - -from __gin__ import dynamic_registration - -from t5x import optimizers -from t5x import utils -import optax - -include "t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_from_scratch.gin" - -# In this case, we choose to switch to the AdamW optimizer with gradient clip. -OPTIMIZER = @optimizers.chain() - -optimizers.chain: - transformations = [@optax.clip(), @optax.adamw()] - -optax.clip: - max_delta = 1.0 - -optax.adamw: - # Unlike Adafactor, most optimizers require to specify - # `learning_rate`. `learning_rate` accepts a float number (e.g., 1e-4) or - # a schedule function, which should take an argument `step` and output - # a learning rate for that step. - # As for choices of schedule functions, we can either use T5x - # learning rate scheduler, i.e., utils.create_learning_rate_scheduler, or - # optax's native schedule functions, e.g., warmup_cosine_decay_schedule. - learning_rate = @optax.warmup_cosine_decay_schedule() - -optax.warmup_cosine_decay_schedule: - init_value = 0.0 - peak_value = 1e-4 - warmup_steps = 1000 - decay_steps = %TRAIN_STEPS - end_value = 0.0 - - -# Below is an example of using the T5X's schedule functions. -# Feel free to uncomment to try. -# optax.adamw: -# learning_rate = @utils.create_learning_rate_scheduler() - -# utils.create_learning_rate_scheduler: -# factors = 'constant * linear_warmup * rsqrt_decay' -# base_learning_rate = 0.01 -# warmup_steps = 10000 - - diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_infer.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_infer.gin deleted file mode 100644 index 872c16322961b102f0337d9e3c0043969d0bc777..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/base_wmt_infer.gin +++ /dev/null @@ -1,19 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as infer_script -from t5.data import mixtures -from t5x import partitioning -from t5x import utils - -include "t5x/contrib/gpu/t5/t5_1_1/base.gin" -include "t5x/contrib/gpu/t5/configs/runs/infer.gin" - -DROPOUT_RATE = 0.0 # unused but needs to be specified -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 64, "targets": 64} - -partitioning.PjitPartitioner.num_partitions = 1 - -utils.DatasetConfig: - split = "test" - batch_size = 32 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin deleted file mode 100644 index 12e2ee6fa523fc194e36a8d84afeed0105022eb3..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin +++ /dev/null @@ -1,23 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -include "t5x/contrib/gpu/t5/t5_1_1/large.gin" -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' -include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" - -MIXTURE_OR_TASK_NAME = "glue_mnli_v2" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} -TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. -DROPOUT_RATE = 0.1 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -LOSS_NORMALIZING_FACTOR = 131077 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_pile_pretrain.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_pile_pretrain.gin deleted file mode 100644 index c3e6c312cc3af6b97962d54d7f00e29e21a8c6f1..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_pile_pretrain.gin +++ /dev/null @@ -1,13 +0,0 @@ -include 't5x/contrib/gpu/t5/t5_1_1/large.gin' -include 't5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin' -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' - -# Register necessary SeqIO Tasks/Mixtures -import t5.data.mixtures -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -MIXTURE_OR_TASK_NAME = "the_pile_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 128} -TRAIN_STEPS = 1000000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 2048 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin deleted file mode 100644 index 87b896f581922c8e040357ce9334ddf02d7fa527..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin +++ /dev/null @@ -1,24 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -include "t5x/contrib/gpu/t5/t5_1_1/large.gin" -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' -include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" - -MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" -TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} -TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. -DROPOUT_RATE = 0.1 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -# 2048 * 128 here -LOSS_NORMALIZING_FACTOR = 262144 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_c4_pretrain.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_c4_pretrain.gin deleted file mode 100644 index d3640a6dde29fa75fc8602ef6a5af03b1577a79b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_c4_pretrain.gin +++ /dev/null @@ -1,11 +0,0 @@ -include 't5x/contrib/gpu/t5/t5_1_1/small.gin' -include 't5x/contrib/gpu/t5/configs/runs/pretrain.gin' - -# Register necessary SeqIO Tasks/Mixtures. -import t5.data.mixtures - -MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} -TRAIN_STEPS = 10000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 256 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin deleted file mode 100644 index 7f600f12b517ada2ffd07997d37c82e16dd9cfd1..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin +++ /dev/null @@ -1,24 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -include "t5x/contrib/gpu/t5/t5_1_1/small.gin" -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' -include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" - -MIXTURE_OR_TASK_NAME = "glue_mnli_v2" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} -TRAIN_STEPS = 1_015_001 # 1000000 pre-trained steps + 15000 fine-tuning steps. -DROPOUT_RATE = 0.1 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -# 1024 * 128 here -LOSS_NORMALIZING_FACTOR = 131077 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_pile_pretrain.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_pile_pretrain.gin deleted file mode 100644 index 655b404c38397a3c45af3feb277c471707624b58..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_pile_pretrain.gin +++ /dev/null @@ -1,13 +0,0 @@ -include 't5x/contrib/gpu/t5/t5_1_1/small.gin' -include 't5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin' -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' - -# Register necessary SeqIO Tasks/Mixtures -import t5.data.mixtures -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -MIXTURE_OR_TASK_NAME = "the_pile_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 128} -TRAIN_STEPS = 1000000 -DROPOUT_RATE = 0.1 -BATCH_SIZE = 2048 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin deleted file mode 100644 index ba5af03d9c49c3949ec1b41ffcf67d920e5b17c5..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin +++ /dev/null @@ -1,24 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -include "t5x/contrib/gpu/t5/t5_1_1/small.gin" -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' -include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" - -MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" -TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} -TRAIN_STEPS = 1_015_001 # 1000000 pre-trained steps + 15000 fine-tuning steps. -DROPOUT_RATE = 0.1 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -# 2048 * 128 here -LOSS_NORMALIZING_FACTOR = 262144 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_wmt_finetune.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_wmt_finetune.gin deleted file mode 100644 index 4dc14db2093d030a9bec37caba3cacf7babe4d40..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/small_wmt_finetune.gin +++ /dev/null @@ -1,21 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils - -include "t5x/contrib/gpu/t5/t5_1_1/small.gin" -include "t5x/contrib/gpu/t5/configs/runs/finetune.gin" - -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} -TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. -DROPOUT_RATE = 0.0 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -LOSS_NORMALIZING_FACTOR = 233472 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/test_train_eval_t5_tiny.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/test_train_eval_t5_tiny.gin deleted file mode 100644 index d3318e94f4a275fe0fd7c3998a64b2fd8b780b17..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/test_train_eval_t5_tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# Test config to exercise train.py, very similar to test_train_t5_tiny.gin, -# except this only does evaluation, no training. - -from __gin__ import dynamic_registration - -import __main__ as train_script - -include 't5x/contrib/gpu/t5/t5_1_1/examples/test_train_t5_tiny.gin' - -train_script.train: - run_eval_before_training = True - eval_period = 0 - total_steps = 0 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/test_train_t5_tiny.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/test_train_t5_tiny.gin deleted file mode 100644 index 6a959d27560d476b38bf8a107e940654f488c468..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/test_train_t5_tiny.gin +++ /dev/null @@ -1,56 +0,0 @@ -# Test config to exercise train.py with model-based pjit partitioning. - -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import adafactor -from t5x import models -from t5x import partitioning -from t5x import trainer -from t5x import utils - -include 't5x/contrib/gpu/t5/configs/runs/pretrain.gin' -include 't5x/contrib/gpu/t5/t5_1_1/tiny.gin' - -MODEL_DIR = "/tmp" # Will be overridden in test. - -TRAIN_STEPS = 3 -MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" -MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 32, "targets": 32} -DROPOUT_RATE = 0.0 - -models.EncoderDecoderModel: - z_loss = 0.0 - label_smoothing = 0.0 - loss_normalizing_factor = None - - -train/utils.DatasetConfig: - pack = False - seed = 0 - shuffle = False - use_cached = False - batch_size = 8 - -train_eval/utils.DatasetConfig: - pack = False - seed = 0 - shuffle = False - use_cached = False - batch_size = 8 - -train_script.train: - random_seed = 0 - eval_steps = 2 - actions={'TRAIN_EVAL': [@trainer.TerminateOnNanAction()]} - -trainer.TerminateOnNanAction: - task = %MIXTURE_OR_TASK_NAME - -partitioning.PjitPartitioner.num_partitions = 2 -utils.SaveCheckpointConfig.period = 4 - -# Overriding from pretrain.gin to keep magic constants in tests. -utils.create_learning_rate_scheduler: - warmup_steps = 1000 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin deleted file mode 100644 index 7a585622572b1cc90c715b60ce5b946963ff6424..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin +++ /dev/null @@ -1,24 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -include "t5x/contrib/gpu/t5/t5_1_1/xl.gin" -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' -include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" - -MIXTURE_OR_TASK_NAME = "glue_mnli_v2" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} -TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. -DROPOUT_RATE = 0.1 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -# 1024 * 128 here -LOSS_NORMALIZING_FACTOR = 131077 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_pile_pretrain.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_pile_pretrain.gin deleted file mode 100644 index 80b6e20f7391cd3d3741bb6e92f44b83f7c25682..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_pile_pretrain.gin +++ /dev/null @@ -1,13 +0,0 @@ -include 't5x/contrib/gpu/t5/t5_1_1/xl.gin' -include 't5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin' -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' - -# Register necessary SeqIO Tasks/Mixtures -import t5.data.mixtures -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -MIXTURE_OR_TASK_NAME = "the_pile_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 128} -TRAIN_STEPS = 1000000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 2048 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin deleted file mode 100644 index 6a4c7b2898eebde54da9d820e68b5d77ae36bee4..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin +++ /dev/null @@ -1,24 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -include "t5x/contrib/gpu/t5/t5_1_1/xl.gin" -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' -include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" - -MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" -TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} -TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. -DROPOUT_RATE = 0.1 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -# 2048 * 128 here -LOSS_NORMALIZING_FACTOR = 262144 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xxl_pile_pretrain.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xxl_pile_pretrain.gin deleted file mode 100644 index 565d14244d7c395a968d43c8f4b6697713069af2..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/examples/xxl_pile_pretrain.gin +++ /dev/null @@ -1,13 +0,0 @@ -include 't5x/contrib/gpu/t5/t5_1_1/xxl.gin' -include 't5x/contrib/gpu/t5/configs/runs/pretrain_pile.gin' -include 't5x/contrib/gpu/t5/t5_1_1/adamw_opt.gin' - -# Register necessary SeqIO Tasks/Mixtures -import t5.data.mixtures -import t5x.contrib.gpu.scripts_gpu.seqio_tasks - -MIXTURE_OR_TASK_NAME = "the_pile_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 128} -TRAIN_STEPS = 1000000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 2304 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/large.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/large.gin deleted file mode 100644 index 2485a990e983c07721e1b938f9930a1d5f148d49..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Large model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 2816 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/small.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/small.gin deleted file mode 100644 index fef09fa3d4ce972784155f2fea24587868cf2477..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Small model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 6 - num_encoder_layers = 8 - num_decoder_layers = 8 - head_dim = 64 - mlp_dim = 1024 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/tiny.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/tiny.gin deleted file mode 100644 index 04ad2c1ec22c0da73cde09cf9ea691e3bdf2fd63..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/xl.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/xl.gin deleted file mode 100644 index c36dff5379041f87a287f26c3d730fcce73290ca..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/xl.gin +++ /dev/null @@ -1,14 +0,0 @@ -# T5.1.1 XL model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2048 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 5120 - diff --git a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/xxl.gin b/t5x-main/t5x/contrib/gpu/t5/t5_1_1/xxl.gin deleted file mode 100644 index cd6e18786aceaf3761660cb0a898d26f137d9b54..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/gpu/t5/t5_1_1/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 XXL model. - -include 't5x/contrib/gpu/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4096 - num_heads = 64 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 10240 diff --git a/t5x-main/t5x/contrib/moe/README.md b/t5x-main/t5x/contrib/moe/README.md deleted file mode 100644 index 6da7f167af44ef2ee8b7c663279cc881d4c28014..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# Mixture of Experts - - -This repo contains overrides and configs for training sparse Mixture of Experts -(MoE) models with T5X. The existing setups and examples all use [Flaxformer](https://github.com/google/flaxformer). - -## Training standard MoE architectures - -If you are looking train a T5X variant of a popular Mesh Tensorflow MoE model -(e.g. [Switch Transformer](https://arxiv.org/abs/2101.03961) or [Sparsely-Gated Mixture-of-Experts](https://arxiv.org/abs/1701.06538)) or adapt existing -MoE models, then the easiest way to get started is to plug one of the -[(Flaxformer) model gin configs](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models) -into the [T5X Quickstart guide](https://github.com/google-research/t5x). To customize the default MoE models, you can override aspects of the underlying [(Flaxformer) architecture gin config](https://github.com/google/flaxformer/blob/main/flaxformer/t5x/configs/moe/architectures/moe.gin). - -## Using MoE in your existing model - -Alternatively, if you already have your own existing T5X/Flaxformer model -architecture and wish to add MoE layers, you can directly use the -[Flaxformer MoeLayer](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/moe_layers.py#L67). -Currently, the MoeLayer is constrained to use -[Flaxformer MlpBlock(s)](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/components/dense.py#L185) -as experts. As a point of reference: MoeLayer(s) are integrated with the Flaxformer T5 -architecture through the -[SparseEncoder](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/moe_architecture.py#L36) -and -[SparseDecoder](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/moe_architecture.py#L162). -These classes allow us to interleave sparse MoE and dense MLP blocks through the -`sparse_layout` attribute. - -## Expert routing mechanisms - -A number of routing mechanisms are supported: - -* Switch routing (or top-1 "tokens choose" routing) based on the - [Switch Transformer](https://arxiv.org/abs/2101.03961) -* General Top-k "tokens choose" routing of the form used in - [Sparsely-Gated Mixture-of-Experts](https://arxiv.org/abs/1701.06538), - [Vision MoE](https://arxiv.org/abs/2106.05974), - [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906) - and many other MoE works -* "Experts choose" routing introduced in - [Mixture-of-Experts with Expert Choice Routing](https://arxiv.org/abs/2202.09368) - -See the -[Flaxformer router codebase](https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py) for details. - diff --git a/t5x-main/t5x/contrib/moe/__init__.py b/t5x-main/t5x/contrib/moe/__init__.py deleted file mode 100644 index 184349df14c5b61f2b8c353d58984561b0bd31e7..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Import API modules.""" - -import t5x.contrib.moe.adafactor_utils -import t5x.contrib.moe.models -import t5x.contrib.moe.partitioning -import t5x.contrib.moe.trainer -import t5x.contrib.moe.training_utils -# Version number. -from t5x.version import __version__ diff --git a/t5x-main/t5x/contrib/moe/adafactor_utils.py b/t5x-main/t5x/contrib/moe/adafactor_utils.py deleted file mode 100644 index 7af4eed722cb5e9c7672634c27a08e1d1e7a221f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/adafactor_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Adafactor logical rules for Mixture of Experts models.""" - -from flax import core as flax_core -from t5x import adafactor - -FactorDim = adafactor.FactorDim -FrozenDict = flax_core.FrozenDict - - -def logical_factor_rules() -> FrozenDict: - """Logical factor rules for Mixture of Experts.""" - rules = flax_core.unfreeze(adafactor.standard_logical_factor_rules()) - rules.update({ - 'expert': FactorDim.BATCH, - 'expert_mlp': FactorDim.COLUMN, - 'unmodeled': FactorDim.NONE, - 'mlp_embed': FactorDim.ROW, # Same factoring as 'embed' - }) - return flax_core.freeze(rules) diff --git a/t5x-main/t5x/contrib/moe/checkpoints.py b/t5x-main/t5x/contrib/moe/checkpoints.py deleted file mode 100644 index 4233b853bda07d99581878510a8abbf00aeefa0c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/checkpoints.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mixture-of-Experts checkpoint overrides.""" - -import os -from typing import Any, Optional, Union - -import clu.data -import jax -from jax.experimental.array_serialization import serialization as array_serialization -from jax.experimental.pjit import pjit -import jax.numpy as jnp -import numpy as np -from t5x import checkpoint_importer -from t5x import checkpoints -from t5x import partitioning -from t5x import train_state as train_state_lib -import tensorflow as tf -import tensorstore as ts - -LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray -_ParameterInfo = checkpoints._ParameterInfo # pylint: disable=protected-access -PartitionSpec = partitioning.PartitionSpec - - -class UpcycleCheckpointer(checkpoints.Checkpointer): - """Modified Checkpointer for sparse upcycling (dense-to-sparse) runs. - - This subclass calls modified _read_ts, namely _read_upcycle_ts, which - broadcasts the checkpoint's dense MLP weights to the model's sparse, expert - weights. This enables sparsifying dense checkpoints. See also _read_upcycle_ts - for more details. - """ - - def __init__( # pytype: disable=annotation-type-mismatch # jnp-type - self, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - checkpoints_dir: str, - num_experts: int, - dataset_iterator: Optional[ - Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] - ] = None, - *, - keep: Optional[int] = None, - save_dtype: jnp.dtype = np.float32, - restore_dtype: Optional[jnp.dtype] = None, - keep_dataset_checkpoints: Optional[int] = None, - ): - """Checkpointer constructor. - - Args: - train_state: A train state to be used to determine the structure of the - parameter tree, and the *full* (non-partitioned) parameter shapes and - dtypes. Saved and restored train states must match this structure. - partitioner: The partitioner to use for determining the local chunks - mapping or to perform params partitioning on restore. - checkpoints_dir: a path to a directory to save checkpoints in and restore - them from. - num_experts: Global number of experts. - dataset_iterator: An optional iterator to save/restore. - keep: An optional maximum number of checkpoints to keep. If more than this - number of checkpoints exist after a save, the oldest ones will be - automatically deleted to save space. - save_dtype: Dtype to cast targets to before saving. - restore_dtype: Optional dtype to cast targets to after restoring. If None, - no parameter casting is performed. - keep_dataset_checkpoints: An optional maximum number of data iterators to - keep. If more than this number of data iterators exist after a save, the - oldest ones will be automatically deleted to save space. - """ - super().__init__( - train_state=train_state, - partitioner=partitioner, - checkpoints_dir=checkpoints_dir, - dataset_iterator=dataset_iterator, - keep=keep, - save_dtype=save_dtype, - restore_dtype=restore_dtype, - keep_dataset_checkpoints=keep_dataset_checkpoints, - ) - - self._num_experts = num_experts - - def _create_lazy_awaitable_array( - self, - param_info: _ParameterInfo, - maybe_ts_spec: Any, - ckpt_path: str, - restore_dtype: Optional[jnp.dtype], - ) -> LazyAwaitableArray: - """Creates LazyArray from tensorstore and optionally broadcasts it. - - Does not materialize the array immediately. - - The only difference of this method from that of the parent class is that - this one calls _read_upcycle_ts instead of _read_ts, which also performs - broadcasting the MoE weights and optimizer states for sparsely upcycled - models. - - Args: - param_info: Information about how to read the parameter, host based sliced - reads and the like. - maybe_ts_spec: The tensorstore spec to read the parameter or some other - object. If this is an array then we will do a host based sliced read on - it (provided the param_info says to). Anything else we just return. - ckpt_path: A base location to use when resolving the relative paths in the - tensorstore spec. - restore_dtype: Type to restore as. None indicates that no cast is - requested. - - Returns: - LazyArray object. If it is an MLP parameter kernel that needs to be - "sparsified", then the MLP parameter kernel is broadcast to all experts. - """ - mesh = self._partitioner.mesh - axes = param_info.axes - - async def get_fn(): - nonlocal mesh - nonlocal axes - arr = await _read_upcycle_ts( - param_info, - maybe_ts_spec, - ckpt_path, - self._num_experts, - restore_dtype=restore_dtype, - mesh=mesh, - axes=axes, - ) - - is_sharded_jax_array = ( - isinstance(arr, jax.Array) and not arr.is_fully_addressable - ) - if ( - isinstance(arr, (np.ndarray, jnp.ndarray)) - and not is_sharded_jax_array - ): - if axes is None: - axes = PartitionSpec( - None, - ) - if restore_dtype is not None: - arr = arr.astype(restore_dtype) - arr = jax.make_array_from_callback( - arr.shape, - jax.sharding.NamedSharding(mesh, axes), - lambda idx: arr[idx], - ) - return arr - - return LazyAwaitableArray.from_tensor_store_spec_or_array( - maybe_ts_spec, get_fn, dtype=restore_dtype - ) - - -async def _read_upcycle_ts( - param_info: _ParameterInfo, - maybe_ts_spec: Any, - ckpt_path: str, - num_experts: int, - restore_dtype: Optional[jnp.dtype] = None, - mesh: Optional[jax.sharding.Mesh] = None, - axes: Optional[jax.sharding.PartitionSpec] = None, -): - """Reads array from tensorstore and handles broadcasting of expert weights. - - If both `mesh` and `axes` are provided, the method will attempt to restore the - array as a GlobalDeviceArray. - - This method is adapted from _read_ts() in t5x/checkpoints.py. This variant - broadcasts dense MLP weights from the checkpoint to the sparse, expert weights - of the model. - - Args: - param_info: Information about how to read the parameter, host based sliced - reads and the like. - maybe_ts_spec: The tensorstore spec to read the parameter or some other - object. If this is an array then we will do a host based sliced read on it - (provided the param_info says to). Anything else we just return. - ckpt_path: A base location to use when resolving the relative paths in the - tensorstore spec. - num_experts: Global number of experts. - restore_dtype: type to restore as. None indicates that no cast is requested. - mesh: Mesh object for GDA restoration. - axes: MeshAxes object for GDA restoration. - - Returns: - The array. Depending on the value `maybe_ts_spec` it might be read from - tensorstore, or it might be returned as is. Depending on the values in - param_info (specifically the `local_chunk_info`) it might be the full value - or a specific slice. If it is an expert parameter, then it is broadcast to - all experts. - """ - if param_info: - param_name = param_info.name - m_or_v = param_name.endswith('/m') or param_name.endswith('/v') - is_expert_param = 'expert/' in param_name - - # If saved as a numpy array, but a partitioned read is requested, return a - # slice of the array for that host. Otherwise, return the whole thing. - if isinstance(maybe_ts_spec, np.ndarray) and param_info: - if mesh is not None and axes is not None: - # Using GDA, return global array without selecting local chunk. - return maybe_ts_spec - elif param_info.local_chunk_info: - return maybe_ts_spec[param_info.local_chunk_info.slice] - else: - return maybe_ts_spec - # If we have anything else that isn't a tensorstore spec just return it. - elif not isinstance(maybe_ts_spec, ts.Spec): - return maybe_ts_spec - - tmp_ts_spec_dict = maybe_ts_spec.to_json() - # Remove non-required params so that we can open Tensorstore - # that was created with a different set of params. - del tmp_ts_spec_dict['metadata']['chunks'] - del tmp_ts_spec_dict['metadata']['compressor'] - - # Convert the relative path in the spec to a path based on the checkpoint - # location. Path and gcs bucket (if applicable) information is updated - # in-place. - checkpoints._update_ts_path_from_relative_to_absolute( # pylint:disable=protected-access - os.path.dirname(ckpt_path), tmp_ts_spec_dict - ) - - if param_info.shape is not None: - ts_spec_arr_shape = tuple(tmp_ts_spec_dict['metadata']['shape']) - # Check that the shapes of the array on disk match the expected shape based - # on the optimizer that is being restored. - if (not m_or_v) and is_expert_param: - shapes_match = ts_spec_arr_shape == param_info.shape[1:] - else: - shapes_match = ts_spec_arr_shape == param_info.shape - if not shapes_match: - raise ValueError( - f'Shape of `{param_info.name}` in checkpoint ' - f'{ts_spec_arr_shape} does not match expected ' - f'{param_info.shape}.' - ) - - if ( - 'dtype' in tmp_ts_spec_dict and tmp_ts_spec_dict['dtype'] == 'uint16' - ) or ( - 'dtype' in tmp_ts_spec_dict['metadata'] - and tmp_ts_spec_dict['metadata']['dtype'] == ' jax.Array: - """Converts NumPy array into sharded JAX Array.""" - return jax.make_array_from_callback( - arr.shape, - jax.sharding.NamedSharding(global_mesh, mesh_axes), - lambda idx: arr[idx], - ) - - -def make_train_state( - *, - step: Optional[int], - params: Mapping[str, Any], - param_states: Mapping[str, Any], - flax_optimizer_def: optimizers.OptimizerDefType = optimizers.sgd(0.1), -) -> FlaxOptimTrainState: - """Helper to construct train state for testing.""" - optimizer = optimizers.Optimizer( - flax_optimizer_def, - state=optimizers.OptimizerState(step=step, param_states=param_states), - target=params, - ) - return FlaxOptimTrainState(optimizer) - - -def shard_train_state( - *, - train_state: FlaxOptimTrainState, - global_mesh: Optional[jax.sharding.Mesh], - mesh_axes: Optional[PartitionSpec], -) -> FlaxOptimTrainState: - """Helper to construct a sharded train state from NumPy arrays.""" - return jax.tree.map( - functools.partial( - create_sharded_array, global_mesh=global_mesh, mesh_axes=mesh_axes - ), - train_state, - is_leaf=lambda x: isinstance(x, np.ndarray), - ) - - -def host_count_to_layout(host_count: int) -> Tuple[int, int, int, int]: - """Determines the host layout from the host count.""" - return { - 1: (2, 2, 1, 2), - 2: (4, 2, 1, 2), - 4: (4, 4, 1, 2), - 8: (4, 8, 1, 2), - 16: (8, 8, 1, 2), - 32: (8, 16, 1, 2), - }[host_count] - - -class CheckpointsTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - - self.num_experts = 32 - - # The dense model is the checkpointed model that we seek to restore as a - # sparse model. The dense train state does NOT need to be sharded because - # it is only used for saving and validation. - self.dense_model_train_state = make_train_state( - step=np.int32(42), - params={ - 'mlp': { - 'kernel': np.arange(128, dtype=np.float32).reshape((8, 16)), - }, - 'attention': { - 'kernel': np.arange(64, dtype=np.float32).reshape((8, 8)), - }, - }, - param_states={ - 'mlp': { - 'kernel': 2 * np.arange(64, dtype=np.uint8), - }, - 'attention': { - 'kernel': 3 * np.arange(64, dtype=np.uint8), - }, - }, - ) - self.dense_model_mesh_axes = make_train_state( - step=None, - params={ - 'mlp': { - 'kernel': PartitionSpec(None, 'model'), - }, - 'attention': { - 'kernel': PartitionSpec(None, 'model'), - }, - }, - param_states={ - 'mlp': { - 'kernel': None, - }, - 'attention': { - 'kernel': None, - }, - }, - ) - - # The sparse model is the model that we want to restore into. It has two - # differences relative to the dense model: - # (1) 'mlp' --> 'expert' - # (2) 'expert' kernel has self.num_experts copies of the 'mlp' parameters. - # We will need to shard this train state into a JAX Array. - self.sparse_model_train_state = make_train_state( - step=np.int32(42), - params={ - 'expert': { - 'kernel': np.repeat( - np.expand_dims( - np.arange(128, dtype=np.float32).reshape((8, 16)), - axis=0, - ), - self.num_experts, - axis=0, - ), - }, - 'attention': { - 'kernel': np.arange(64, dtype=np.float32).reshape((8, 8)), - }, - }, - param_states={ - 'expert': { - 'kernel': np.repeat( - np.expand_dims(2 * np.arange(64, dtype=np.uint8), axis=0), - self.num_experts, - axis=0, - ), - }, - 'attention': { - 'kernel': 3 * np.arange(64, dtype=np.uint8), - }, - }, - ) - # Axes are the same as the dense model axes, except that we have an - # additional 'expert' axis for the expert kernels. - self.sparse_model_mesh_axes = make_train_state( - step=None, - params={ - 'expert': { - 'kernel': PartitionSpec('expert', None, 'model'), - }, - 'attention': { - 'kernel': PartitionSpec(None, 'model'), - }, - }, - param_states={ - 'expert': { - 'kernel': PartitionSpec('expert', None), - }, - 'attention': { - 'kernel': None, - }, - }, - ) - - self.ds = tf.data.Dataset.range(1024) - - self.checkpoints_dir = self.create_tempdir() - self.tmp_dir = self.checkpoints_dir.full_path - - @unittest.skipIf(jax.__version_info__ < (0, 4, 5), 'Test requires jax 0.4.5') - @mock.patch(f'{jax.process_index.__module__}.process_index') - @mock.patch('jax.devices') - @mock.patch('jax.local_devices') - def get_partitioner( - self, - process_index, - host_count, - num_partitions, - local_devices_fn, - devices_fn, - process_index_fn, - mesh_axes, - params_on_devices: bool = True, - ): - devices = test_utils.make_devices(*host_count_to_layout(host_count)) - devices_fn.return_value = devices - local_devices = [d for d in devices if d.process_index == 0] - local_devices_fn.return_value = local_devices - process_index_fn.return_value = process_index - num_partitions_to_mps = { - 1: (1, 1, 1, 1), - 2: (1, 1, 1, 2), - 4: (2, 1, 1, 2), - 16: (4, 2, 1, 2), - } - mesh = moe_partitioning.default_moe_mesh( - num_expert_partitions=self.num_experts, - num_partitions=num_partitions, - model_parallel_submesh=num_partitions_to_mps[num_partitions], - ) - local_chunker = base_partitioning.LocalChunker(mesh) - - class TestPartitioner(base_partitioning.BasePartitioner): - - def __init__(self): - self.move_params_to_devices_calls = 0 - super().__init__( - num_partitions, None, params_on_devices=params_on_devices - ) - - @property - def _local_chunker(self): - return local_chunker - - @property - def _mesh(self): - return mesh - - def partition( - self, - fn, - in_axis_resources, - out_axis_resources, - static_argnums=(), - donate_argnums=(), - ): - raise NotImplementedError - - def compile(self, partitioned_fn, *args): - raise NotImplementedError - - def move_params_to_devices(self, train_state, train_state_axes): - assert params_on_devices - return train_state - - def get_mesh_axes(self, train_state): - return mesh_axes - - return TestPartitioner() - - # pylint:disable=no-value-for-parameter - @mock.patch( - 'jax.experimental.multihost_utils.sync_global_devices', return_value=None - ) - @mock.patch('time.time', return_value=0) - @mock.patch('jax.process_count') - @mock.patch('jax.process_index') - def call_host_checkpointer( - self, - train_state, - process_index, - process_count, - partitioner, - fn, - save_dtype, - ds_iter, - mock_process_index, - mock_process_count, - unused_mock_host_time, - unused_mock_sync_devices, - restore_dtype=np.float32, - ): - mock_process_index.return_value = process_index - mock_process_count.return_value = process_count - - checkpointer = checkpoints.UpcycleCheckpointer( - train_state, - partitioner, - checkpoints_dir=self.tmp_dir, - num_experts=self.num_experts, - dataset_iterator=ds_iter, - save_dtype=save_dtype, - restore_dtype=restore_dtype, - ) - return fn(checkpointer) - - def validate_restore( - self, - host_count, - num_partitions, - step=42, - checkpoint_dataset=False, - expected_restore_dtype=np.float32, - lazy_parameters=False, - ): - """Verifies that UpcycleCheckpointer correctly sparsifies checkpoint.""" - global_mesh = test_utils.create_global_mesh( - host_count_to_layout(host_count), ('data', 'expert', 'model') - ) - - # We want to restore into the sparse model train state. - sharded_sparse_model_train_state = shard_train_state( - train_state=self.sparse_model_train_state, - global_mesh=global_mesh, - mesh_axes=self.sparse_model_mesh_axes, - ) - - # We map params of saved (dense) model to restored (sparse) model. - assignment_map = ( - (r'(.*)expert(.*)', r'\1mlp\2'), - (r'(.*)attention(.*)', r'\1attention\2'), - ) - # Turn `assignment_map` into a transformation function. - assignment_map_fn = functools.partial( - state_utils.apply_assignment_map, assignment_map=assignment_map - ) - - for i in range(host_count): - partitioner = self.get_partitioner( - i, - host_count, - num_partitions, - params_on_devices=not lazy_parameters, - mesh_axes=self.sparse_model_mesh_axes, - ) - - ds_iter = iter(self.ds) - - actual_train_state = self.call_host_checkpointer( - sharded_sparse_model_train_state, - i, - host_count, - partitioner, - lambda c: c.restore( # pylint: disable=g-long-lambda - step=step, - lazy_parameters=lazy_parameters, - state_transformation_fns=(assignment_map_fn,), - ), - np.float32, - ds_iter if checkpoint_dataset else None, - restore_dtype=expected_restore_dtype, - ) - if lazy_parameters: - actual_train_state = jax.tree.map(lambda x: x.get(), actual_train_state) - - # Validate. - - # Optimizer should be the same between actual (sparse) and original - # (dense) train states. - self.assertEqual( - actual_train_state._optimizer.optimizer_def, - self.dense_model_train_state._optimizer.optimizer_def, - ) - self.assertEqual(actual_train_state.step, step) - self.assertEqual(actual_train_state.step.dtype, np.int32) - self.assertEqual(actual_train_state._optimizer.state.step.dtype, np.int32) - - # Experts are sharded along the 'expert' axis, so each host loads a - # fraction of the expert parameters. - experts_per_host = self.num_experts // host_count - expected_per_host_params = { - 'expert': { - 'kernel': np.repeat( - np.expand_dims( - np.arange(128, dtype=np.float32).reshape((8, 16)), axis=0 - ), - experts_per_host, - axis=0, - ), - }, - 'attention': { - 'kernel': np.arange(64, dtype=np.float32).reshape((8, 8)), - }, - } - expected_per_host_param_states = { - 'expert': { - 'kernel': np.repeat( - np.expand_dims(2 * np.arange(64, dtype=np.uint8), axis=0), - experts_per_host, - axis=0, - ), - }, - 'attention': { - 'kernel': 3 * np.arange(64, dtype=np.uint8), - }, - } - - jax.tree.map( - np.testing.assert_array_equal, - actual_train_state.params, - expected_per_host_params, - ) - jax.tree.map( - np.testing.assert_array_equal, - actual_train_state.param_states, - expected_per_host_param_states, - ) - - self.assertEqual( - actual_train_state.param_states['attention']['kernel'].dtype, np.uint8 - ) - self.assertEqual( - actual_train_state.param_states['expert']['kernel'].dtype, np.uint8 - ) - - self.assertSameElements( - actual_train_state.params, ('attention', 'expert') - ) - - self.assertTrue( - all( - jax.tree.leaves( - jax.tree.map( - lambda x: x.dtype == expected_restore_dtype, - actual_train_state.params, - ) - ) - ) - ) - - expected_params = sharded_sparse_model_train_state.params - expected_param_states = sharded_sparse_model_train_state.param_states - - mlp_slice = partitioner.get_local_chunk_info( - expected_params['expert']['kernel'].shape, ('expert', None, 'model') - ).slice - np.testing.assert_equal( - actual_train_state.params['expert']['kernel'], - expected_params['expert']['kernel'][mlp_slice], - ) - - attn_slice = partitioner.get_local_chunk_info( - expected_params['attention']['kernel'].shape, (None, 'model') - ).slice - np.testing.assert_equal( - actual_train_state.params['attention']['kernel'], - expected_params['attention']['kernel'][attn_slice], - ) - - mlp_state_slice = partitioner.get_local_chunk_info( - expected_param_states['expert']['kernel'].shape, ('expert', None) - ).slice - np.testing.assert_equal( - actual_train_state.param_states['expert']['kernel'], - expected_param_states['expert']['kernel'][mlp_state_slice], - ) - - if checkpoint_dataset: - ds_shard_id = partitioner.get_data_layout().shard_id - # The next value from the restored iterator should equal the replica - # set id. - self.assertEqual(next(ds_iter).numpy(), ds_shard_id) - - def save( - self, - host_count, - num_partitions, - step=42, - save_dtype=np.float32, - checkpoint_dataset=False, - disable_partitioning=False, - ): - """We do not validate saves; UpcycleCheckpointer only overwrites restore.""" - # We save a dense model. We will try to restore it as a sparse model. - global_mesh = test_utils.create_global_mesh( - host_count_to_layout(host_count), ('data', 'model') - ) - - sharded_dense_model_train_state = shard_train_state( - train_state=self.dense_model_train_state, - global_mesh=global_mesh, - mesh_axes=self.dense_model_mesh_axes, - ) - - params = sharded_dense_model_train_state.params - param_states = sharded_dense_model_train_state.param_states - optimizer_def = sharded_dense_model_train_state._optimizer.optimizer_def - # Update these on each save. - step = np.int32(step) - - # Save the parameters and optimizer states. - # Each host sets its partition to its host number + 1. - # Go in reverse since host 0 renames the directory. - for i in reversed(range(host_count)): - partitioner = self.get_partitioner( - i, - host_count, - num_partitions, - mesh_axes=jax.tree.map(lambda x: None, self.dense_model_mesh_axes) - if disable_partitioning - else self.dense_model_mesh_axes, - ) - data_layout = partitioner.get_data_layout() - ds_shard_id = data_layout.shard_id - - mlp_chunk = partitioner.get_local_chunk_info( - params['mlp']['kernel'].shape, (None, 'model') - ) - attn_chunk = partitioner.get_local_chunk_info( - params['attention']['kernel'].shape, (None, 'model') - ) - - ds_iter = iter(self.ds) - - # pylint:disable=cell-var-from-loop - def _save_ckpt(checkpointer): - # Set the checkpoint so that the next value on restore will be the - # replica set id. - for _ in range(ds_shard_id): - next(ds_iter) - - train_state = make_train_state( - step=step, - # We save the dense model params. - params={ - 'mlp': { - 'kernel': params['mlp']['kernel'][mlp_chunk.slice], - }, - 'attention': { - 'kernel': params['attention']['kernel'][attn_chunk.slice], - }, - }, - param_states=param_states, - flax_optimizer_def=optimizer_def, - ) - checkpointer.save(train_state) - - # pylint:enable=cell-var-from-loop - - # Call host checkpointer with dense model train state. - self.call_host_checkpointer( - sharded_dense_model_train_state, - i, - host_count, - partitioner, - _save_ckpt, - save_dtype, - ds_iter if checkpoint_dataset else None, - ) - - # (host_count, num_partitions) - TOPOLOGIES = [ - (1, 1), # 1 host, 1 partition - (1, 2), # 1 host, 2 partitions - (2, 1), # 2 hosts, 1 partition - (2, 2), # 2 hosts, 2 partitions - (4, 4), # 4 hosts, 4 partitions - (4, 1), # 4 hosts, 1 partition - (4, 2), # 4 hosts, 2 partitions - (8, 2), # 8 hosts, 2 partitions - ] - - DTYPES = [ - jnp.int32, - jnp.float32, - jnp.bfloat16, - jnp.uint32, - jnp.int64, - jnp.float64, - ] - - @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES)) - def test_save_restore(self, save_topology, restore_topology): - self.save(*save_topology) - self.validate_restore(*restore_topology) - - @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES)) - def test_save_restore_lazy(self, save_topology, restore_topology): - self.save(*save_topology) - self.validate_restore(*restore_topology, lazy_parameters=True) - - @parameterized.parameters(TOPOLOGIES) - def test_save_restore_dataset(self, *topology): - # Note that we must use the same number of replica sets on save/restore. - self.save(*topology, checkpoint_dataset=True) - self.validate_restore(*topology, checkpoint_dataset=True) - - @parameterized.parameters(itertools.product(DTYPES, DTYPES)) - def test_save_as_type(self, save_dtype, restore_dtype): - self.save(1, 1, save_dtype=save_dtype) - self.validate_restore(1, 1, expected_restore_dtype=restore_dtype) - - @parameterized.parameters(TOPOLOGIES) - def test_save_non_partitioned_restore_partitioned(self, *restore_topology): - # Save without partitioning. - self.save(2, 1, disable_partitioning=True) - # Restore with partitioning. - self.validate_restore(*restore_topology) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/contrib/moe/configs/__init__.py b/t5x-main/t5x/contrib/moe/configs/__init__.py deleted file mode 100644 index a52d4f9529506a53a19a2903bc0796383eb56b78..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This empty file is needed for loading the gin files in this directory.""" diff --git a/t5x-main/t5x/contrib/moe/configs/runs/__init__.py b/t5x-main/t5x/contrib/moe/configs/runs/__init__.py deleted file mode 100644 index a52d4f9529506a53a19a2903bc0796383eb56b78..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This empty file is needed for loading the gin files in this directory.""" diff --git a/t5x-main/t5x/contrib/moe/configs/runs/continue_pretrain.gin b/t5x-main/t5x/contrib/moe/configs/runs/continue_pretrain.gin deleted file mode 100644 index 68b45b48ed460fe035232dde24152039df6253cf..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/continue_pretrain.gin +++ /dev/null @@ -1,28 +0,0 @@ -# Continue a Mixture of Experts pre-training run. -# -# See t5x/contrib/moe/configs/runs/pretrain.gin for instructions. -# -# You must also include bindings for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS -# - INITIAL_CHECKPOINT_PATH -# - MODEL_DIR -# -# You can also specify the upper bound for the size of the expert parallel -# submesh by overriding NUM_EXPERT_PARTITIONS, which defaults to NUM_EXPERTS. - -from __gin__ import dynamic_registration - -from t5x import utils - -include 't5x/contrib/moe/configs/runs/pretrain.gin' - -utils.RestoreCheckpointConfig: - mode = 'specific' - path = %INITIAL_CHECKPOINT_PATH - dtype = 'float32' diff --git a/t5x-main/t5x/contrib/moe/configs/runs/eval.gin b/t5x-main/t5x/contrib/moe/configs/runs/eval.gin deleted file mode 100644 index a066a8a38908962f3b5b82a2936095c08bcbcce0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/eval.gin +++ /dev/null @@ -1,47 +0,0 @@ -# Evaluate a Mixture of Experts model. -# -# -# You must also include bindings for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - MIXTURE_OR_TASK_NAME -# - CHECKPOINT_PATH -# - EVAL_OUTPUT_DIR -# -# You can also specify the upper bound for the size of the expert parallel -# submesh by overriding NUM_EXPERT_PARTITIONS, which defaults to NUM_EXPERTS. -# -# Commonly overridden options (see also t5x/configs/runs/eval.gin): -# -# - DROPOUT_RATE -# - BATCH_SIZE - -from __gin__ import dynamic_registration - -import __main__ as eval_script - -from t5x.contrib.moe import partitioning as moe_partitioning -from t5x import utils - -include 't5x/configs/runs/eval.gin' - -# One of these should be overridden. -NUM_MODEL_PARTITIONS = None -MODEL_PARALLEL_SUBMESH = None - -# Override to decrease the number of expert partitions. This is only an upper -# bound. Must be <= NUM_EXPERTS. Fewer expert partitions places more experts on -# the same device, requiring more expert replicas and greater memory overhead, -# but will reduce all-to-all communication costs. -NUM_EXPERT_PARTITIONS = %NUM_EXPERTS - -# We use the MoE partitioner. -eval_script.evaluate.partitioner = @moe_partitioning.MoePjitPartitioner() -moe_partitioning.MoePjitPartitioner: - num_expert_partitions = %NUM_EXPERT_PARTITIONS - num_partitions = %NUM_MODEL_PARTITIONS - model_parallel_submesh = %MODEL_PARALLEL_SUBMESH - -utils.DatasetConfig.batch_size = %BATCH_SIZE diff --git a/t5x-main/t5x/contrib/moe/configs/runs/export.gin b/t5x-main/t5x/contrib/moe/configs/runs/export.gin deleted file mode 100644 index eb651435793394ac00bc8a0326919a19809b63e5..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/export.gin +++ /dev/null @@ -1,47 +0,0 @@ -# Mixture of Experts model defaults for single_core_export.py. -# -# You must also include bindings for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - TASK_FEATURE_LENGTHS -# - CHECKPOINT_PATH -# - INFER_OUTPUT_DIR -# -# Commonly overridden options (see also t5x/configs/runs/export.gin): -# -# warmup_examples: Optional[List[str]] = None -# jit_compile: bool = False - -from __gin__ import dynamic_registration - -from t5x import export_lib -from t5x.contrib.moe import models -from t5x.contrib.moe import partitioning as moe_partitioning - -include 't5x/configs/runs/export.gin' - - -# Only one of these should be specified. -NUM_MODEL_PARTITIONS = None -MODEL_PARALLEL_SUBMESH = None - -# Fix the number of expert partitions to 1; i.e. all devices hold copies of all -# experts. For multi-core export, we can increase this to partition experts -# across available devices. -NUM_EXPERT_PARTITIONS = 1 - -# We use the MoE partitioner. -export_lib.save.partitioner = @moe_partitioning.MoePjitPartitioner() - -moe_partitioning.MoePjitPartitioner: - num_expert_partitions = %NUM_EXPERT_PARTITIONS - num_partitions = %NUM_MODEL_PARTITIONS - model_parallel_submesh = %MODEL_PARALLEL_SUBMESH - params_on_devices = True - -# And the MoE encoder-decoder model. -models.MoeEncoderDecoderModel.predict_batch_with_aux: - num_decodes = %BEAM_SIZE - return_all_decodes = True diff --git a/t5x-main/t5x/contrib/moe/configs/runs/export_seqio.gin b/t5x-main/t5x/contrib/moe/configs/runs/export_seqio.gin deleted file mode 100644 index e51666f5cd3965469a55a8b978c2134ababebad8..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/export_seqio.gin +++ /dev/null @@ -1,32 +0,0 @@ -# Mixture of Experts model defaults for export with seqio. -# -# You must also include bindings for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - CHECKPOINT_PATH -# - INFER_OUTPUT_DIR - -from __gin__ import dynamic_registration - -from t5x import export_lib - -include 't5x/contrib/moe/configs/runs/export.gin' - - -MIXTURE_OR_TASK_NAME = %gin.REQUIRED - -export_lib.save: - create_preprocessor_fn = @export_lib.create_preprocessor_from_task - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - output_features = None - -export_lib.create_preprocessor_from_task: - model = %MODEL - task_feature_lengths = %TASK_FEATURE_LENGTHS - task_name = %MIXTURE_OR_TASK_NAME - serialized_examples = True - run_precache = False diff --git a/t5x-main/t5x/contrib/moe/configs/runs/finetune.gin b/t5x-main/t5x/contrib/moe/configs/runs/finetune.gin deleted file mode 100644 index 19cca255835e37caf5c093dba5799e822102f3c0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/finetune.gin +++ /dev/null @@ -1,68 +0,0 @@ -# Fine-tune a Mixture of Experts model. -# -# This file allows for fine-tuning with data, expert and model parallelism. -# -# -# You must also include bindings for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS # includes pretrain steps -# - MODEL_DIR -# - INITIAL_CHECKPOINT_PATH -# -# You can also specify the upper bound for the size of the expert parallel -# submesh by overriding NUM_EXPERT_PARTITIONS, which defaults to NUM_EXPERTS. -# -# Commonly overridden options (see also t5x/configs/runs/finetune.gin): -# -# - EXPERT_DROPOUT_RATE -# - DROPOUT_RATE -# - BATCH_SIZE -# - MoeTrainer.num_microbatches - -from __gin__ import dynamic_registration - -import __main__ as train_script - -from t5x.contrib.moe import partitioning as moe_partitioning -from t5x.contrib.moe import trainer as moe_trainer -from t5x import utils - -include 't5x/configs/runs/finetune.gin' - -EXPERT_DROPOUT_RATE = %DROPOUT_RATE # provided by t5x/configs/runs/finetune.gin - -# One of these should be overridden. -NUM_MODEL_PARTITIONS = None -MODEL_PARALLEL_SUBMESH = None - -# Override to decrease the number of expert partitions. This is only an upper -# bound. Must be <= NUM_EXPERTS. Fewer expert partitions places more experts on -# the same device, requiring more expert replicas and greater memory overhead, -# but will reduce all-to-all communication costs. -NUM_EXPERT_PARTITIONS = %NUM_EXPERTS - -# We use the MoE partitioner. -train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner() -moe_partitioning.MoePjitPartitioner: - num_expert_partitions = %NUM_EXPERT_PARTITIONS - num_partitions = %NUM_MODEL_PARTITIONS - model_parallel_submesh = %MODEL_PARALLEL_SUBMESH - -# And the MoE trainer. -train_script.train.trainer_cls = @moe_trainer.MoeTrainer -moe_trainer.MoeTrainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() - num_expert_partitions = %NUM_EXPERT_PARTITIONS -utils.create_learning_rate_scheduler: - factors = 'constant' - base_learning_rate = 0.001 - warmup_steps = 1000 - -# Checkpoint slightly more often than fine-tuning defaults. -utils.SaveCheckpointConfig.period = 2000 diff --git a/t5x-main/t5x/contrib/moe/configs/runs/infer.gin b/t5x-main/t5x/contrib/moe/configs/runs/infer.gin deleted file mode 100644 index a0cfca808337fac8f5f8c42caf34f35a8072fd32..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/infer.gin +++ /dev/null @@ -1,48 +0,0 @@ -# Run inference with a Mixture of Experts model. -# -# -# You must also include bindings for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - CHECKPOINT_PATH -# - INFER_OUTPUT_DIR -# -# You can also specify the upper bound for the size of the expert parallel -# submesh by overriding NUM_EXPERT_PARTITIONS, which defaults to NUM_EXPERTS. -# -# Commonly overridden options (see also t5x/configs/runs/infer.gin): -# -# - DROPOUT_RATE -# - BATCH_SIZE - -from __gin__ import dynamic_registration - -import __main__ as infer_script - -from t5x.contrib.moe import partitioning as moe_partitioning -from t5x import utils - -include 't5x/configs/runs/infer.gin' - -# One of these should be overridden. -NUM_MODEL_PARTITIONS = None -MODEL_PARALLEL_SUBMESH = None - -# Override to decrease the number of expert partitions. This is only an upper -# bound. Must be <= NUM_EXPERTS. Fewer expert partitions places more experts on -# the same device, requiring more expert replicas and greater memory overhead, -# but will reduce all-to-all communication costs. -NUM_EXPERT_PARTITIONS = %NUM_EXPERTS - -# We use the MoE partitioner. -infer_script.infer.partitioner = @moe_partitioning.MoePjitPartitioner() -moe_partitioning.MoePjitPartitioner: - num_expert_partitions = %NUM_EXPERT_PARTITIONS - num_partitions = %NUM_MODEL_PARTITIONS - model_parallel_submesh = %MODEL_PARALLEL_SUBMESH - -utils.DatasetConfig.batch_size = %BATCH_SIZE diff --git a/t5x-main/t5x/contrib/moe/configs/runs/infer_from_tfexample_file.gin b/t5x-main/t5x/contrib/moe/configs/runs/infer_from_tfexample_file.gin deleted file mode 100644 index f83f0b30122c17bd651beec4e434e98ad6aa629e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/infer_from_tfexample_file.gin +++ /dev/null @@ -1,47 +0,0 @@ -# Mixture of Experts defaults for infer.py if using a TFExample file as input. -# -# -# The features from each TFExample are tokenized using the model's vocabulary. -# By default, the inputs feature is assumed to be keyed as 'inputs', but this -# can be overridden with `create_task_from_tfexample_file.inputs_key`. -# -# You must include a binding for MODEL. -# -# Required to be set: -# -# - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from. -# - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently -# supported: 'tfrecord', 'recordio', 'sstable'. -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - FEATURE_LENGTHS: The maximum length per feature in the TF Examples. -# - CHECKPOINT_PATH: The model checkpoint to use for inference -# - INFER_OUTPUT_DIR: The dir to write results to. -# -# See also t5x/configs/runs/infer_from_tfexample_file.gin for commonly -# overridden options. - -from __gin__ import dynamic_registration - -import __main__ as infer_script -from t5x.contrib.moe import partitioning as moe_partitioning - -include 't5x/configs/runs/infer_from_tfexample_file.gin' - -infer_script.infer.partitioner = @moe_partitioning.MoePjitPartitioner() - -# One, and only one, of these should be specified. -NUM_MODEL_PARTITIONS = 1 -MODEL_PARALLEL_SUBMESH = None - -# Override to decrease the number of expert partitions. This is only an upper -# bound. Must be <= NUM_EXPERTS. Fewer expert partitions places more experts on -# the same device, requiring more expert replicas and greater memory overhead, -# but will reduce all-to-all communication costs. -NUM_EXPERT_PARTITIONS = %NUM_EXPERTS - -# We use the MoE partitioner. -train_script.precompile.partitioner = @moe_partitioning.MoePjitPartitioner() -moe_partitioning.MoePjitPartitioner: - num_expert_partitions = %NUM_EXPERT_PARTITIONS - num_partitions = %NUM_MODEL_PARTITIONS - model_parallel_submesh = %MODEL_PARALLEL_SUBMESH diff --git a/t5x-main/t5x/contrib/moe/configs/runs/precompile.gin b/t5x-main/t5x/contrib/moe/configs/runs/precompile.gin deleted file mode 100644 index 81b3f2fd2cf705f83bfc722bbe0f174a236ada27..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/precompile.gin +++ /dev/null @@ -1,43 +0,0 @@ -# Mixture of Experts defaults for precompile mode in main.py. -# -# You must also include a binding for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS -# - MODEL_DIR: # automatically set when using xm_launch -# -# Commonly overridden options (see also t5x/configs/runs/precompile.gin): -# -# - USE_CACHED_TASKS -# - BATCH_SIZE -# -# You can also specify the upper bound for the size of the expert parallel -# submesh by overriding NUM_EXPERT_PARTITIONS, which defaults to NUM_EXPERTS. - -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x.contrib.moe import partitioning as moe_partitioning - -include 't5x/configs/runs/precompile.gin' - -# One of these should be overridden. -NUM_MODEL_PARTITIONS = None -MODEL_PARALLEL_SUBMESH = None - -# Override to decrease the number of expert partitions. This is only an upper -# bound. Must be <= NUM_EXPERTS. Fewer expert partitions places more experts on -# the same device, requiring more expert replicas and greater memory overhead, -# but will reduce all-to-all communication costs. -NUM_EXPERT_PARTITIONS = %NUM_EXPERTS - -# We use the MoE partitioner. -train_script.precompile.partitioner = @moe_partitioning.MoePjitPartitioner() -moe_partitioning.MoePjitPartitioner: - num_expert_partitions = %NUM_EXPERT_PARTITIONS - num_partitions = %NUM_MODEL_PARTITIONS - model_parallel_submesh = %MODEL_PARALLEL_SUBMESH diff --git a/t5x-main/t5x/contrib/moe/configs/runs/pretrain.gin b/t5x-main/t5x/contrib/moe/configs/runs/pretrain.gin deleted file mode 100644 index 0f1150bf2ee10100d9fae1ac1b718b5969d54638..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/pretrain.gin +++ /dev/null @@ -1,66 +0,0 @@ -# Pre-train a Mixture of Experts model. -# -# This file allows for pre-training with data, expert and model parallelism. To -# use model parallelism, set NUM_MODEL_PARTITIONS > 1. -# -# -# You must also include bindings for MODEL. -# -# Required to be set: -# -# - NUM_MODEL_PARTITIONS or MODEL_PARALLEL_SUBMESH (only specify one) -# - MIXTURE_OR_TASK_NAME -# - TASK_FEATURE_LENGTHS -# - TRAIN_STEPS -# - MODEL_DIR -# -# You can also specify the upper bound for the size of the expert parallel -# submesh by overriding NUM_EXPERT_PARTITIONS, which defaults to NUM_EXPERTS. -# -# Commonly overridden options (see also t5x/configs/runs/pretrain.gin): -# -# - BATCH_SIZE -# - MoeTrainer.num_microbatches -# - DROPOUT_RATE - -from __gin__ import dynamic_registration - -import __main__ as train_script - -from t5x.contrib.moe import partitioning as moe_partitioning -from t5x.contrib.moe import trainer as moe_trainer -from t5x import utils - -include 't5x/configs/runs/pretrain.gin' - -# One of these should be overridden. -NUM_MODEL_PARTITIONS = None -MODEL_PARALLEL_SUBMESH = None - -# Override to decrease the number of expert partitions. This is only an upper -# bound. Must be <= NUM_EXPERTS. Fewer expert partitions places more experts on -# the same device, requiring more expert replicas and greater memory overhead, -# but will reduce all-to-all communication costs. -NUM_EXPERT_PARTITIONS = %NUM_EXPERTS - -# We use the MoE partitioner. -train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner() -moe_partitioning.MoePjitPartitioner: - num_expert_partitions = %NUM_EXPERT_PARTITIONS - num_partitions = %NUM_MODEL_PARTITIONS - model_parallel_submesh = %MODEL_PARALLEL_SUBMESH - -# And the MoE trainer. -train_script.train.trainer_cls = @moe_trainer.MoeTrainer -moe_trainer.MoeTrainer: - num_microbatches = None - learning_rate_fn = @utils.create_learning_rate_scheduler() - num_expert_partitions = %NUM_EXPERT_PARTITIONS -utils.create_learning_rate_scheduler: - factors = 'constant * rsqrt_decay' - base_learning_rate = 1.0 - warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. - -# Keep slightly fewer checkpoints than pre-training defaults. -utils.SaveCheckpointConfig.period = 5000 -utils.SaveCheckpointConfig.keep = 20 \ No newline at end of file diff --git a/t5x-main/t5x/contrib/moe/configs/runs/sparse_upcycle.gin b/t5x-main/t5x/contrib/moe/configs/runs/sparse_upcycle.gin deleted file mode 100644 index 1ac7873075f7205315b00356eb6f8877e913909d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/configs/runs/sparse_upcycle.gin +++ /dev/null @@ -1,29 +0,0 @@ -# Sparsely upcycles a pretrained dense model. -# -# You must also include bindings for MODEL and NUM_EXPERTS (typically set by the -# model gin config). -# -# See t5x/contrib/moe/configs/runs/continue_pretrain.gin for other required -# bindings. - -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import utils -from t5x.contrib.moe import checkpoints - -include 't5x/contrib/moe/configs/runs/continue_pretrain.gin' - -utils.RestoreCheckpointConfig: - fallback_to_scratch = True - checkpointer_cls = @checkpoints.UpcycleCheckpointer - assignment_map = ( - (r'target(.*)mlp\/expert(.*)', r'target\1mlp\2'), # Replace dense MLPs with sparse variants - (r'.*\/router\/.*', None), # Initialize router weights from scratch - (r'state\/param_states.*', None), # Initialize optimizer states from scratch - ) - -checkpoints.UpcycleCheckpointer.num_experts = %NUM_EXPERTS - -# Upcycle using JAX arrays. -train_script.train.use_jax_array = True diff --git a/t5x-main/t5x/contrib/moe/models.py b/t5x-main/t5x/contrib/moe/models.py deleted file mode 100644 index 005b54b4d9cb3bacba746df062177ff2c4bf92a7..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/models.py +++ /dev/null @@ -1,425 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Provides model subclasses with Mixture of Experts support.""" - -from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Tuple, Union - -import clu.metrics as clu_metrics -from flax import core as flax_core -from flax import linen as nn -from flax import traverse_util -from flax.core import scope as flax_scope -import jax -import jax.numpy as jnp -import seqio -from t5x import decoding -from t5x import losses -from t5x import metrics as metrics_lib -from t5x import models as base_models -from t5x import optimizers - -AveragePerStep = metrics_lib.AveragePerStep -DecodeFnCallable = base_models.DecodeFnCallable -FrozenVariableDict = flax_scope.FrozenVariableDict -MetricsMap = metrics_lib.MetricsMap -PyTree = base_models.PyTree -Sum = metrics_lib.Sum - -MOE_METRICS = ( - 'auxiliary_loss', - 'router_z_loss', - 'fraction_tokens_left_behind', - 'expert_usage', - 'router_confidence', -) - - -class MoeEncoderDecoderModel(base_models.EncoderDecoderModel): - """Encoder-decoder subclass which propagates MoE auxiliary loss & metrics.""" - - def __init__( - self, - module: nn.Module, - input_vocabulary: seqio.Vocabulary, - output_vocabulary: seqio.Vocabulary, - optimizer_def: optimizers.OptimizerDefType, - decode_fn: DecodeFnCallable = decoding.beam_search, - feature_converter_cls: Optional[ - Callable[..., seqio.FeatureConverter] - ] = None, - label_smoothing: float = 0.0, - z_loss: float = 0.0, - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] = None, - aux_loss_factor: float = 0.0, - router_z_loss_factor: float = 0.0, - ): - super().__init__( - module=module, - input_vocabulary=input_vocabulary, - output_vocabulary=output_vocabulary, - optimizer_def=optimizer_def, - decode_fn=decode_fn, - feature_converter_cls=feature_converter_cls, - label_smoothing=label_smoothing, - z_loss=z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - self._aux_loss_factor = aux_loss_factor - self._router_z_loss_factor = router_z_loss_factor - - def loss_fn( - self, - params: base_models.PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jnp.ndarray], - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Cross-entropy loss function with auxiliary MoE losses. - - Args: - params: Model parameters. - batch: Batch of training examples. - dropout_rng: Random number generator key for dropout. - - Returns: - - Model loss. - - Metrics. - """ - logits, state = self._compute_logits( - params, batch, dropout_rng, mutable=['intermediates'] - ) - return _moe_loss_fn( - batch, - logits, - state, - self._label_smoothing, - self._z_loss, - self._loss_normalizing_factor, - self._aux_loss_factor, - self._router_z_loss_factor, - ) - - def predict_batch_with_aux( # pylint: disable=useless-super-delegation - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.Array] = None, - decoder_params: Optional[MutableMapping[str, Any]] = None, - return_all_decodes: bool = False, - num_decodes: int = 1, - prompt_with_targets: bool = False, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Predict with fast decoding beam search on a batch. - - This override is only included for dependency injection configurability - (e.g. gin). See parent method docstring for detailed description. - - Args: - params: Model parameters. - batch: Batch of inputs. - rng: RNG key to use during prediction. - decoder_params: Additional (model-independent) parameters for the decoder. - return_all_decodes: Whether to return the entire beam or just the top-1. - num_decodes: Number of beams to use in beam search. - prompt_with_targets: Whether to force decode decoder_inputs. - - Returns: - - Batch of predictions, with the entire beam if requested, - - Auxiliary dictionary of decoder scores. - """ - return super().predict_batch_with_aux( - params, - batch, - rng, - decoder_params, - return_all_decodes, - num_decodes, - prompt_with_targets, - ) - - -class MoeDecoderOnlyModel(base_models.DecoderOnlyModel): - """Decoder-only subclass which propagates MoE auxiliary loss and metrics.""" - - def __init__( - self, - module: nn.Module, - vocabulary: seqio.Vocabulary, - optimizer_def: optimizers.OptimizerDefType, - decode_fn: DecodeFnCallable = decoding.temperature_sample, - inputs_bidirectional_attention: bool = False, - feature_converter_cls: Optional[ - Callable[..., seqio.FeatureConverter] - ] = None, - label_smoothing: float = 0.0, - z_loss: float = 0.0, - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] = None, - aux_loss_factor: float = 0.0, - router_z_loss_factor: float = 0.0, - ): - super().__init__( - module=module, - vocabulary=vocabulary, - optimizer_def=optimizer_def, - decode_fn=decode_fn, - inputs_bidirectional_attention=inputs_bidirectional_attention, - feature_converter_cls=feature_converter_cls, - label_smoothing=label_smoothing, - z_loss=z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - self._aux_loss_factor = aux_loss_factor - self._router_z_loss_factor = router_z_loss_factor - - def loss_fn( - self, - params: base_models.PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jnp.ndarray], - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Cross-entropy loss function with auxiliary MoE losses. - - Args: - params: Model parameters. - batch: Batch of training examples. - dropout_rng: Random number generator key for dropout. - - Returns: - - Model loss. - - Metrics. - """ - logits, state = self._compute_logits( - params, batch, dropout_rng, mutable=['intermediates'] - ) - return _moe_loss_fn( - batch, - logits, - state, - self._label_smoothing, - self._z_loss, - self._loss_normalizing_factor, - self._aux_loss_factor, - self._router_z_loss_factor, - ) - - def predict_batch_with_aux( # pylint: disable=useless-super-delegation - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.Array] = None, - *, - return_all_decodes: bool = False, - num_decodes: int = 1, - decoder_params: Optional[MutableMapping[str, Any]] = None, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Predict with prefix. - - This override is only included for dependency injection configurability - (e.g. gin). See parent method docstring for detailed description. - - Args: - params: Model parameters. - batch: Batch of inputs with the model features specified in - seqio.DecoderFeatureConverter. - rng: RNG key to use during prediction. - return_all_decodes: Whether to return the entire beam or just the top-1. - num_decodes: Number of decoded sequences to be returned. - decoder_params: Additional (model-independent) parameters for the decoder. - - Returns: - Sampled sequences of shape [batch, max_decode_length]. - """ - return super().predict_batch_with_aux( - params, - batch, - rng, - return_all_decodes=return_all_decodes, - num_decodes=num_decodes, - decoder_params=decoder_params, - ) - - -def _moe_loss_fn( - batch: Mapping[str, jnp.ndarray], - logits: jnp.ndarray, - state: flax_scope.FrozenVariableDict, - label_smoothing: float, - z_loss: float, - loss_normalizing_factor: Optional[float], - aux_loss_factor: float, - router_z_loss_factor: float, -) -> Tuple[jnp.ndarray, MetricsMap]: - """Computes combined cross-entropy and MoE auxiliary loss.""" - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] - (loss_normalizing_factor, weights) = ( - losses.get_loss_normalizing_factor_and_weights( - loss_normalizing_factor, batch - ) - ) - - targets = batch['decoder_target_tokens'] - total_loss, z_loss, _ = losses.compute_weighted_cross_entropy( - logits, - targets=targets, - weights=weights, - label_smoothing=label_smoothing, - z_loss=z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - - # Extract and add MoE losses to total loss. - diversity_metrics = _extract_diversity_metrics(state) - - aux_loss, router_z_loss = _expert_losses( - diversity_metrics, aux_loss_factor, router_z_loss_factor - ) - total_loss += aux_loss + router_z_loss - - metrics = base_models.compute_base_metrics( - logits=logits, - targets=targets, - mask=weights, - loss=total_loss, - z_loss=z_loss, - ) - metrics.update( - _expert_metrics( # pytype: disable=wrong-arg-types # jax-ndarray - diversity_metrics, - total_loss, - z_loss, - aux_loss, - router_z_loss, - num_tokens=targets.size, - ) - ) - - return total_loss, metrics - - -def _extract_diversity_metrics( - state: flax_scope.FrozenVariableDict, -) -> Dict[str, jnp.ndarray]: - """Extract average expert diversity metrics from sown state intermediates. - - Args: - state: Model state holding sown intermediate metrics. - - Returns: - Diversity metrics, averaged across MoE layers. - - Raises: - ValueError if unable to extract diversity metrics from model state. - """ - state_dict = traverse_util.flatten_dict(flax_core.unfreeze(state)) - - avg_metrics = {} - for metric in MOE_METRICS: - summed_metric = 0.0 - count = 0 - for path, value in state_dict.items(): - if path[-1] == metric: - summed_metric += jnp.asarray(value, dtype=jnp.float32) - count += 1 - - if count == 0: - raise ValueError( - f'Unable to find expert metric: {metric}. Please check that MoE ' - 'metrics and losses are correctly sown.' - ) - - avg_metrics[metric] = summed_metric / count - - return avg_metrics - - -def _expert_losses( - diversity_metrics: Mapping[str, jnp.ndarray], - auxiliary_loss_factor: float, - router_z_loss_factor: float, -) -> Tuple[float, float]: - """Summarizes per-layer MoE auxiliary losses. - - For auxiliary losses, we take the mean across MoE layers. - - Args: - diversity_metrics: Per-layer mixture of expert metrics. - auxiliary_loss_factor: Factor by which to scale auxiliary load balancing - loss for mixture of experts models. The raw auxiliary losses will be - summed and then scaled by this factor. - router_z_loss_factor: Factor by which to scale router z-loss for mixture of - experts models. - - Returns: - - Load balancing loss. - - Router z-loss. - """ - aux_loss = auxiliary_loss_factor * diversity_metrics['auxiliary_loss'].mean() - router_z_loss = ( - router_z_loss_factor * diversity_metrics['router_z_loss'].mean() - ) - return aux_loss, router_z_loss # pytype: disable=bad-return-type # jax-ndarray - - -def _expert_metrics( - diversity_metrics: Mapping[str, jnp.ndarray], - total_loss: float, - z_loss: float, - auxiliary_loss: float, - router_z_loss: float, - num_tokens: int, -) -> MetricsMap: - """Summarizes per-layer expert metrics for the entire model. - - The return metrics map will also contain overrides for the cross entropy loss - metrics to account for the MoE losses. - - Args: - diversity_metrics: Per-layer mixture of expert metrics. - total_loss: Total model loss. - z_loss: Output logits z-loss (not MoE specific). - auxiliary_loss: Auxiliary load balancing loss for MoE models. - router_z_loss: Router z-loss for MoE models. - num_tokens: Total number of target tokens. - - Returns: - Expert diversity metrics. - """ - cross_ent_loss = total_loss - z_loss - auxiliary_loss - router_z_loss - return { - 'experts/auxiliary_loss': AveragePerStep.from_model_output( - auxiliary_loss - ), - 'experts/router_z_loss': AveragePerStep.from_model_output(router_z_loss), - 'experts/fraction_tokens_left_behind': AveragePerStep.from_model_output( - diversity_metrics['fraction_tokens_left_behind'].mean() - ), - 'experts/expert_usage': AveragePerStep.from_model_output( - diversity_metrics['expert_usage'].mean() - ), - 'experts/router_confidence': AveragePerStep.from_model_output( - diversity_metrics['router_confidence'].mean() - ), - # Override vanilla T5 cross entropy loss metrics with corrected loss that - # accounts for MoE losses. - 'cross_ent_loss': metrics_lib.AveragePerStep(total=cross_ent_loss), - 'cross_ent_loss_per_all_target_tokens': clu_metrics.Average( # pytype: disable=wrong-arg-types # jnp-array - total=jnp.sum(cross_ent_loss), count=num_tokens - ), - } diff --git a/t5x-main/t5x/contrib/moe/models_test.py b/t5x-main/t5x/contrib/moe/models_test.py deleted file mode 100644 index b16f25b704951e8eab3cdd710007ec31b9ee353b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/models_test.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for models.""" - -import functools -from unittest import mock - -from absl.testing import absltest -from clu import metrics as clu_metrics_lib -from flax import core as flax_core -import jax.numpy as jnp -import numpy as np -from t5x import decoding -from t5x import metrics as metrics_lib -from t5x.contrib.moe import models - -Accuracy = clu_metrics_lib.Accuracy -Average = clu_metrics_lib.Average -AveragePerStep = metrics_lib.AveragePerStep -FrozenDict = flax_core.frozen_dict.FrozenDict - - -class ModelsTest(absltest.TestCase): - - def test_expert_losses(self): - diversity_metrics = { - 'auxiliary_loss': jnp.array([1.0, 2.0]), - 'router_z_loss': jnp.array([0.0, 1.0]), - 'fraction_tokens_left_behind': jnp.array([0.5, 0.5]), - 'expert_usage': jnp.array([0.5, 0.5]), - 'router_confidence': jnp.array([0.5, 0.5]), - } - aux_loss, router_z_loss = models._expert_losses( - diversity_metrics, auxiliary_loss_factor=0.1, router_z_loss_factor=10 - ) - - self.assertEqual(aux_loss, 0.15) - self.assertEqual(router_z_loss, 5.0) - - def test_expert_metrics(self): - diversity_metrics = { - 'auxiliary_loss': jnp.array([1.0, 2.0]), - 'router_z_loss': jnp.array([0.0, 1.0]), - 'fraction_tokens_left_behind': jnp.array([1.0, 0.5]), - 'expert_usage': jnp.array([0.7, 0.5]), - 'router_confidence': jnp.array([0.5, 0.5]), - } - actual_metrics = models._expert_metrics( - diversity_metrics, - total_loss=100.0, - z_loss=1.0, - auxiliary_loss=3.0, - router_z_loss=7.0, - num_tokens=2, - ) - actual_metrics = metrics_lib.set_step_metrics_num_steps(actual_metrics, 1) - actual_computed_metrics = { - k: v.compute() for k, v in actual_metrics.items() - } - - expected_metrics = { - 'cross_ent_loss': 89.0, - 'cross_ent_loss_per_all_target_tokens': 44.5, - 'experts/auxiliary_loss': 3.0, - 'experts/expert_usage': 0.6, - 'experts/fraction_tokens_left_behind': 0.75, - 'experts/router_confidence': 0.5, - 'experts/router_z_loss': 7.0, - } - self.assertEqual(actual_computed_metrics, expected_metrics) - - def test_extract_diversity_metrics(self): - state = flax_core.freeze({ - 'intermediates': { - 'moe_layer_0': { - 'auxiliary_loss': (jnp.array([[0.2]]), jnp.array([[0.1]])), - 'router_z_loss': (jnp.array([[0.1]]), jnp.array([[0.2]])), - 'router_confidence': (jnp.array([[0.4]]), jnp.array([[0.2]])), - 'expert_usage': (jnp.array([[0.9]]), jnp.array([[0.2]])), - 'fraction_tokens_left_behind': ( - jnp.array([[0.1]]), - jnp.array([[0.2]]), - ), - }, - 'moe_layer_1': { - 'auxiliary_loss': (jnp.array([[0.2]]), jnp.array([[0.0]])), - 'router_z_loss': (jnp.array([[0.1]]), jnp.array([[0.0]])), - 'router_confidence': (jnp.array([[0.4]]), jnp.array([[0.5]])), - 'expert_usage': (jnp.array([[0.9]]), jnp.array([[0.8]])), - 'fraction_tokens_left_behind': ( - jnp.array([[0.1]]), - jnp.array([[0.3]]), - ), - }, - } - }) - extracted_metrics = models._extract_diversity_metrics(state) - - expected_raw_metrics = { - 'auxiliary_loss': jnp.array([[[0.2]], [[0.05]]], dtype=jnp.float32), - 'router_z_loss': jnp.array([[[0.1]], [[0.1]]], dtype=jnp.float32), - 'fraction_tokens_left_behind': jnp.array( - [[[0.1]], [[0.25]]], dtype=jnp.float32 - ), - 'expert_usage': jnp.array([[[0.9]], [[0.5]]], dtype=jnp.float32), - 'router_confidence': jnp.array([[[0.4]], [[0.35]]], dtype=jnp.float32), - } - for metric, expected_value in expected_raw_metrics.items(): - np.testing.assert_allclose(extracted_metrics[metric], expected_value) - - def test_extract_from_non_expert_model(self): - empty_state = FrozenDict({'intermediates': {}}) - with self.assertRaisesRegex(ValueError, 'Unable to find expert metric'): - models._extract_diversity_metrics(empty_state) - - def test_encoder_decoder_model(self): - encoder_input_tokens = jnp.ones((2, 3)) - decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) - decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) - decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]) - dummy_logits = jnp.arange(0, 24).reshape((2, 4, 3)) - params = {'foo': jnp.zeros(3)} - - mock_transformer = mock.Mock() - mock_transformer.apply.return_value = dummy_logits - mock_transformer.dtype = jnp.float32 - - batch = { - 'encoder_input_tokens': encoder_input_tokens, - 'decoder_input_tokens': decoder_input_tokens, - 'decoder_target_tokens': decoder_target_tokens, - 'decoder_loss_weights': decoder_loss_weights, - } - - def mock_init(self): - self.module = mock_transformer - - with mock.patch.object( - models.MoeEncoderDecoderModel, '__init__', new=mock_init - ): - model = models.MoeEncoderDecoderModel() - result = model.score_batch(params, batch) - - mock_transformer.apply.assert_called_with( - {'params': params}, - encoder_input_tokens, - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - encoder_positions=None, - decoder_positions=None, - decode=False, - enable_dropout=False, - rngs=None, - mutable=False, - ) - np.testing.assert_allclose(result, [-3.2228181, -1.8152122], rtol=1e-5) - - def test_decoder_only_model(self): - batch = { - 'decoder_input_tokens': jnp.array( - [[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]] - ), - 'decoder_causal_attention': jnp.array( - [[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]] - ), - } - params = {} - - dummy_logits = jnp.expand_dims( - jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1 - ) - - mock_transformer = mock.Mock() - mock_transformer.apply.return_value = (dummy_logits, {'cache': {}}) - mock_transformer.dtype = jnp.float32 - - def mock_init(self): - self.module = mock_transformer - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) - self._inputs_bidirectional_attention = False - - with mock.patch.object( - models.MoeDecoderOnlyModel, '__init__', new=mock_init - ): - model = models.MoeDecoderOnlyModel() - - actual = model.predict_batch(params, batch) - expected = [[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 3, 0]] - np.testing.assert_array_equal(actual, expected) - - def test_moe_loss_fn(self): - batch = { - 'encoder_input_tokens': jnp.ones((2, 3)), - 'decoder_input_tokens': jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]), - 'decoder_target_tokens': jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]), - 'decoder_loss_weights': jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]), - } - logits = jnp.arange(0, 24).reshape((2, 4, 3)) - state = flax_core.freeze({ - 'intermediates': { - 'auxiliary_loss': (jnp.array([[0.2]]),), - 'router_z_loss': (jnp.array([[0.1]]),), - 'router_confidence': (jnp.array([[0.4]]),), - 'expert_usage': (jnp.array([[0.9]]),), - 'fraction_tokens_left_behind': (jnp.array([[0.1]]),), - } - }) - - loss, metrics = models._moe_loss_fn( - batch, - logits, - state, - label_smoothing=0.0, - z_loss=0.0, - loss_normalizing_factor=None, - aux_loss_factor=0.1, - router_z_loss_factor=0.01, - ) - - self.assertAlmostEqual(loss, 5.0590305) - self.assertContainsSubset( - { - 'experts/auxiliary_loss': AveragePerStep( - total=jnp.array(0.02), steps=1 - ), - 'experts/router_z_loss': AveragePerStep( - total=jnp.array(0.001), steps=1 - ), - 'experts/router_confidence': AveragePerStep( - total=jnp.array(0.5), steps=1 - ), - 'experts/expert_usage': AveragePerStep( - total=jnp.array(0.9), steps=1 - ), - 'experts/fraction_tokens_left_behind': AveragePerStep( - total=jnp.array(0.1), steps=1 - ), - 'accuracy': Accuracy(total=jnp.array(2.0), count=jnp.array(5)), - 'cross_ent_loss': AveragePerStep( - steps=1, total=jnp.array(5.0590305) - ), - 'loss': AveragePerStep(steps=1, total=jnp.array(5.0590305)), - 'timing/seqs_per_second': metrics_lib.TimeRate( - duration=None, numerator=2 - ), - 'timing/steps_per_second': metrics_lib.StepsPerTime( - duration=None, steps=1 - ), - }, - metrics, - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/contrib/moe/partitioning.py b/t5x-main/t5x/contrib/moe/partitioning.py deleted file mode 100644 index c4cd61d4ca2dcc59f550d7ea287d9e4eed21cb46..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/partitioning.py +++ /dev/null @@ -1,583 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Pjit partitioner with Mixture of Experts overrides.""" - -from typing import Any, Callable, Optional, Sequence, Tuple, Union - -from absl import logging -import cached_property -from flax import core as flax_core -import jax -from jax.experimental.pjit import pjit -from jax.sharding import Mesh -import numpy as np -from t5x import adafactor -from t5x import optimizers -from t5x import partitioning as base_partitioning -from t5x import train_state as train_state_lib -from t5x.contrib.moe import training_utils - -DataLayout = base_partitioning.DataLayout -FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState -HardwareMesh = base_partitioning.HardwareMesh -InferenceState = train_state_lib.InferenceState -JaxDevice = jax.Device -LogicalAxisRules = base_partitioning.LogicalAxisRules -PartitionSpec = base_partitioning.PartitionSpec -Pytree = Any -TrainState = train_state_lib.TrainState - - -def get_cpu_mesh() -> Mesh: - """Trivial MoE mesh for CPU Testing.""" - base_cpu_mesh = base_partitioning.get_cpu_mesh() - # Add extra dimension for new 'expert' axis. - devices = np.expand_dims(base_cpu_mesh.devices, axis=-1) - return Mesh(devices, ['data', 'expert', 'model']) - - -def get_gpu_mesh() -> Mesh: - """Simple MoE mesh for GPUs.""" - base_gpu_mesh = base_partitioning.get_gpu_mesh(jax.local_device_count()) - # Move devices from the 'model' to the 'expert' axis. - devices = np.expand_dims(base_gpu_mesh.devices, axis=-1) - return Mesh(devices, ['data', 'expert', 'model']) - - -def default_moe_mesh( - num_expert_partitions: int, - num_partitions: Optional[int] = None, - model_parallel_submesh: Optional[HardwareMesh] = None, - backend: Optional[str] = None, -) -> Mesh: - """Construct default xmap/pjit mesh for MoE. - - Unlike the vanilla T5X mesh, this mesh has three resource axes: - - 'expert': 1D submesh with length that divides into `num_expert_partitions`, - - 'model': specified by the provided `model_parallel_submesh` shape, and - - 'data', which covers the rest of the mesh. - - Relative to the vanilla T5X mesh, the `expert` axis is constructed by - factoring along the 'data' axis length. - - Args: - num_expert_partitions: Upper bound for size of expert parallel submesh. This - must be <= the number of experts. Actual values depends on number of - available devices. - num_partitions: Specifies the size of the model parallel submesh to be - automatically selected for the current topology. See - `model_parallel_submesh` for details on how this submesh is used. Mutually - exclusive with `model_parallel_submesh`. - model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh - model-parallel device tile. See also t5x/partitioning.py for details. This - argument is mutually exclusive with `num_partitions`. - backend: Fetch devices from the pinned backend, if specified. This is useful - for explicitly specifying the devices other than relying on - jax_platform_name. - - Returns: - xmap/pjit 3D Mesh with 'data', 'expert' and 'model' mesh axes. - """ - # Base mesh has shape ('data', 'model'). - logging.info('For MoE, first construct vanilla T5X (data, model) mesh.') - base_default_mesh = base_partitioning.default_mesh( - num_partitions, model_parallel_submesh, backend - ) - data_axis_size, model_axis_size = base_default_mesh.devices.shape - - # Factor out the largest divisor of 'data' axis satisfying <= - # `num_expert_partitions`. - expert_axis_size = num_expert_partitions - while data_axis_size % expert_axis_size != 0: - expert_axis_size -= 1 - - # Reshape mesh to ('data', 'expert', 'model'). - devices = base_default_mesh.devices.reshape( - -1, expert_axis_size, model_axis_size - ) - global_mesh = Mesh(devices, ['data', 'expert', 'model']) - logging.info( - 'Overridden MoE global_mesh axes_names: %s', global_mesh.axis_names - ) - logging.info('Overridden MoE global_mesh devices: %s', global_mesh.devices) - logging.info( - 'Overridden MoE global_mesh shape: %s', global_mesh.devices.shape - ) - return global_mesh - - -class MoePjitPartitioner(base_partitioning.PjitPartitioner): - """Pjit partitioner with overrides for Mixture of Experts support. - - This MoE partitioner overrides the default partitioner to use the MoE friendly - ('data', 'expert', 'model') mesh. MoE params and state are partitioned along - the 'expert' axis. Data is partitioned along both of the 'data' AND 'expert' - axes. - - Additionally, when training with T5X's Adafactor optimizer, it handles an edge - case where the MoE optimizer state terms do NOT automatically inherit the - 'expert' axis annotation from the model params; see get_logical_axes(). - """ - - def __init__( - self, - num_expert_partitions: int, - num_partitions: Optional[int] = None, - model_parallel_submesh: Optional[HardwareMesh] = None, - params_on_devices: bool = True, - logical_axis_rules: Optional[LogicalAxisRules] = None, - state_filter_fn: Optional[Callable[[str], bool]] = None, - ): - """Configures the partitioner. - - TODO(jamesleethorp): Rename num_partitions -> num_model_partitions. - - Args: - num_expert_partitions: Specifies the upper bound for size of the expert - parallel submesh. This must be <= the number of experts. Actual value - depends on number of available devices. - num_partitions: Specifies the size of the model parallel submesh to be - automatically selected for the current topology. See - `model_parallel_submesh` for details on how this submesh is used. - Mutually exclusive with `model_parallel_submesh`. - model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh - model-parallel device tile -- an axis of accelerator parallelism - orthogonal to data parallelism. See t5x/partitioning.py for details. - This argument is mutually exclusive with `num_partitions`. - params_on_devices: Whether to keep the params on devices. If False, params - stay in the host memory. - logical_axis_rules: A priority-ordered sequence of KV tuples that maps - logical axis names to either `None` (not sharded), 'model' (to shard - across the model-parallel submesh), 'data' (to shard across the - data-parallel submesh), or 'expert' (for expert parallelism). - state_filter_fn: Function to identify which optimizer state axis rules - should be overridden to be sharded along the 'expert' axis. If None - (default), Adafactor expert sharding overrides are used. - """ - if logical_axis_rules is None: - logical_axis_rules = standard_logical_axis_rules() - - super().__init__( - num_partitions=num_partitions, - model_parallel_submesh=model_parallel_submesh, - params_on_devices=params_on_devices, - logical_axis_rules=logical_axis_rules, - ) - - self._num_expert_partitions = num_expert_partitions - self._state_filter_fn = state_filter_fn - - @property - def data_partition_spec(self) -> PartitionSpec: - """Returns MoE data partitioning spec. - - Data is sharded across the 'expert' and 'data' axes. - - Returns: - Mesh dependent partition spec. - """ - return PartitionSpec( - ('expert', 'data'), - ) - - @cached_property.cached_property - def mesh(self) -> Mesh: - """Overrides default T5X mesh with ('data', 'expert', 'model') mesh.""" - return default_moe_mesh( - self._num_expert_partitions, - self._num_partitions, - self._model_parallel_submesh, - self._backend, - ) - - def get_data_layout( - self, batch_size: Optional[int] = None, host_index: Optional[int] = None - ) -> DataLayout: - """Returns filled `DataLayout` based on the partitioned model layout. - - Overrides default data layout for MoE, where we treat 'data' and 'expert' - axes as "data" axes. - - Args: - batch_size: If set, indicates the requested batch size. If not set, the - batch size is inferred from the layout. - host_index: Indicates the host index to use for the calculations, if not - set - use JAX-provided one. Should be in [0, num_hosts) interval and the - order should match the order of corresponding CPU devices in - `jax.devices()`. - - Returns: - Filled `DataLayout` structure. - """ - if host_index is not None: - raise NotImplementedError('Explicit host_index is not yet implemented.') - - num_data_partitions = self._local_chunker.global_mesh.shape['data'] - num_expert_partitions = self._local_chunker.global_mesh.shape['expert'] - - data_mesh_size = num_data_partitions * num_expert_partitions - batch_size = batch_size or data_mesh_size - if batch_size % data_mesh_size: - raise ValueError( - f'Batch size ({batch_size}) must be divisible by entire data mesh ' - f'size ({data_mesh_size}). Note that for MoE, the data mesh spans ' - 'both the "expert" and "data" virtual mesh axes.' - ) - - num_shards = ( - self._local_chunker.num_chunks['data'] - * self._local_chunker.num_chunks['expert'] - ) - if batch_size % num_shards: - raise ValueError( - f'Batch size ({batch_size}) must be divisible by total number of ' - f'shards ({num_shards}) across "data" and "expert" mesh axes.' - ) - - # Partition the batch over both of the 'expert' and 'data' axes. - global_array_shape = ( - num_expert_partitions, - batch_size // num_expert_partitions, - ) - replica_id = self._local_chunker.get_local_chunk_info( - global_array_shape, ('expert', 'data') - ).replica_id - - return DataLayout( - batch_size=batch_size, - shard_id=( - self._local_chunker.chunk_ids['data'] - + self._local_chunker.chunk_ids['expert'] - * self._local_chunker.num_chunks['data'] - ), - num_shards=num_shards, - is_first_host_in_replica_set=(replica_id == 0), - ) - - def get_logical_axes( # pytype: disable=signature-mismatch # overriding-parameter-type-checks - self, train_state: Union[FlaxOptimTrainState, InferenceState] - ) -> Union[FlaxOptimTrainState, InferenceState]: - """Returns a copy of TrainState with Optional[AxisNames] as leaves. - - Overrides the default logical axes by prepending the 'expert' axis to any - MoE optimizer state terms (identified by self._state_filter_fn); this is - useful for T5X's Adafactor optimizer, which does not propagate param - annotations to the optimizer state when the optimizers is factored. - - Args: - train_state: Object holding all relevant training of inference state. - - Returns: - State object matching structure of input train_state but with axis names - as leaves. - """ - logical_axes = train_state.as_logical_axes() - - if isinstance(logical_axes, InferenceState): - # InferenceState does not contain any optimizer state, so we skip all - # expert partitioning overrides. - return logical_axes - else: - train_state: FlaxOptimTrainState - - state_filter_fn = self._state_filter_fn or _infer_state_filter_fn( - train_state - ) - if state_filter_fn is None: - # No state updates required. - return logical_axes - - prepend_expert = ( # pylint: disable=g-long-ternary - lambda x: PartitionSpec(*('expert',) + x) # pylint: disable=g-long-lambda disable=g-long-ternary - if x - else PartitionSpec( - 'expert', - ) - ) - optimizer_axes = logical_axes._optimizer # pylint: disable=protected-access - state_dict = flax_core.unfreeze(optimizer_axes.state_dict()) - state_dict['state']['param_states'] = training_utils.tree_map_with_names( - prepend_expert, state_dict['state']['param_states'], state_filter_fn - ) - if logical_axes.flax_mutables: - state_dict['flax_mutables'] = logical_axes.flax_mutables - - return train_state.restore_state(state_dict) - - def partition( - self, - fn: Callable, # pylint: disable=g-bare-generic - in_axis_resources: Pytree, - out_axis_resources: Pytree, - static_argnums: Union[int, Sequence[int]] = (), - donate_argnums: Union[int, Sequence[int]] = (), - ) -> base_partitioning.PjittedFnWithContext: - """Partitions the computation using pjit. - - Overrides the default pjit partitioning to ensure that data is sharded along - both of the 'data' and 'expert' axes. - - Args: - fn: Function to partition. - in_axis_resources: Pytree of structure matching that of arguments to `fn`, - with all actual arguments replaced by resource assignment - specifications. - out_axis_resources: Like `in_axis_resources`, but specifies resource - assignment for function outputs. - static_argnums: Specifies which positional arguments to treat as static - (compile-time constant) in the partitioned function. - donate_argnums: Specifies which argument buffers are "donated" to the - computation. - - Returns: - A partitioned version of the input function. - """ - # Override the partition specs to use 'data' AND 'expert' axes for data - # parallelism. - in_axis_resources = override_partition_specs(in_axis_resources) - out_axis_resources = override_partition_specs(out_axis_resources) - - pjitted = pjit( - fn, - in_shardings=in_axis_resources, - out_shardings=out_axis_resources, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - ) - - return base_partitioning.PjittedFnWithContext( - pjitted, self.mesh, self._logical_axis_rules - ) - - -def standard_logical_axis_rules( - activation_partitioning_dims: int = 2, - parameter_partitioning_dims: int = 1, - additional_rules: Optional[LogicalAxisRules] = None, -): - """Returns partitioning rules for MoE models. - - MoE params and state are partitioned along the 'expert' axis. Data is - partitioned along both of the 'data' AND 'expert' axes. - - The partitioning rules vary based on whether the expert and data axes need to - be decoupled; see also MoePjitPartitioner for details of when expert and data - axes need to be decoupled. - - Defaults to 2-D activation sharding for efficiency; see - https://arxiv.org/abs/2211.05102. - - Buyer beware: 2D parameter sharding (`parameter_partitioning_dims=2`) is - technically supported but untested. - - Args: - activation_partitioning_dims: Enables 2-D activation sharding when set to 2. - parameter_partitioning_dims: Enables 2-D parameter sharding when set to 2. - additional_rules: Additional rules (a sequence of tuples) that will be - appended to the standard rules. - - Returns: - Sequence of logical axis rules. - """ - _ = base_partitioning.global_mesh_defined() - - if parameter_partitioning_dims == 2: - raise logging.warning( - '2D parameter sharding (`parameter_partitioning_dims=2`) is supported ' - 'but untested for MoE.' - ) - - default_rules = base_partitioning.standard_logical_axis_rules( - activation_partitioning_dims, parameter_partitioning_dims - ) - moe_rules = [ - ('expert', 'expert'), # Shard experts along the expert axis - ('expert_mlp', 'model'), # Expert MLPs partitioned along model axis - ('expert_replicas', 'data'), # Experts replicated along "pure" data axis - ('unmodeled', None), # Replicated weights - ] - standard_rules = list(default_rules) + moe_rules - if additional_rules: - standard_rules.extend(additional_rules) - - overridden_rules = [] - for logical_axis, mesh_axis in standard_rules: - if logical_axis == 'batch': - # Data is sharded across both 'data' and 'expert axes. - overridden_mesh_axis = ('expert', 'data') - else: - overridden_mesh_axis = mesh_axis - overridden_rules.append((logical_axis, overridden_mesh_axis)) - - return overridden_rules - - -def compute_num_model_partitions( - num_model_partitions: Optional[int], - model_parallel_submesh: Optional[HardwareMesh], -) -> int: - """Returns number of model partitions. - - Args: - num_model_partitions: Specifies the size of the model parallel submesh. - model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh - model-parallel device tile - - Returns: - Size of model parallel submesh. - - Raises: - ValueError if neither num_model_partitions nor model_parallel_submesh are - specified, or if they are inconsistent. - """ - if num_model_partitions is None and model_parallel_submesh is None: - raise ValueError( - 'At least one of num_model_partitions and ' - 'model_parallel_submesh must be specified.' - ) - - if num_model_partitions is not None: - if model_parallel_submesh is not None and num_model_partitions != np.prod( - model_parallel_submesh - ): - raise ValueError( - 'num_model_partitions and model_parallel_submesh are inconsistent. ' - 'Received: %s and %s' % (num_model_partitions, model_parallel_submesh) - ) - return num_model_partitions - else: - return np.prod(model_parallel_submesh) - - -def override_partition_specs(resources: Pytree): - """Override raw axis resources so data is sharded over 'data' & 'expert' axes. - - Here, we only override any raw partition specs that are hardcoded in T5X - libraries: - PartitionSpec('data',) -> PartitionSpec(('expert', 'data'),) - - NOTE: We do not (and there is no need) to override any params or optimizer - state (which appear as large Pytrees) as these will inherit the correct specs - from the logical axis rules; see also standard_logical_axis_rules(). - - Args: - resources: Axis resource assignment specifications. - - Returns: - Axis resources with partition specs overridden to use 'model' as secondary - 'data' axis. - """ - - def _maybe_override_spec(axis_resource: Pytree): - """Overrides raw "data" partition specs; leaves others unchanged.""" - if axis_resource == PartitionSpec( - 'data', - ): - # Shard all data across 'data' and 'expert' axes. - return PartitionSpec( - ('expert', 'data'), - ) - else: - return axis_resource - - if isinstance(resources, PartitionSpec): - return _maybe_override_spec(resources) - elif isinstance(resources, Sequence): - overridden_resources = [] - for resource in resources: - overridden_resources.append(_maybe_override_spec(resource)) - return tuple(overridden_resources) - else: - return resources - - -def _infer_state_filter_fn( - train_state: FlaxOptimTrainState, -) -> Optional[Callable[[str], bool]]: - """Infers relevant regex matching sharded expert model state for train state. - - The model state generally inherits the correct partitioning specs from the - model parameters. In such cases, no state_filter_fn is required. However, - T5X's custom Adafactor optimizer, when factored, requires overrides to the - `v_col` and `v_row` kernel terms; see - https://github.com/google-research/t5x/blob/main/t5x/adafactor.py#L591. For - those cases, we use the state_filter_fn to identify the factored kernel terms - that need to be partitioned along the expert axis. - - Args: - train_state: Object holding optimizer and optimizer state (parameters). - - Returns: - Function to identify which model state is sharded along 'expert' axis. - - Raises: - ValueError if optimizer (on train state) is not a recognized optimizer type. - """ - optimizer = train_state._optimizer # pylint: disable=protected-access - opt_def = optimizer.optimizer_def - - if isinstance(opt_def, optimizers.MultiOptimizer): - if not opt_def.sub_optimizers: - # No suboptimizers, so no state updates are required. - return None - - all_same_type = all( - type(opt) is type(opt_def.sub_optimizers[0]) - for opt in opt_def.sub_optimizers - ) - if not all_same_type: - raise ValueError( - 'optimizers.MultiOptimizer is only supported in cases ' - 'where all suboptimizers are of the same type.' - ) - - if isinstance(opt_def.sub_optimizers[0], adafactor.Adafactor): - all_same_factoring = all( - opt.hyper_params.factored - == opt_def.sub_optimizers[0].hyper_params.factored - for opt in opt_def.sub_optimizers - ) - if not all_same_factoring: - raise ValueError( - 'If using adafactor.Adafactor as the suboptimizer in ' - 'optimizers.MultiOptimizer, all suboptimizers must be either ' - 'factored or unfactored (cannot use mixed factoring).' - ) - - # Use first suboptimizer as representative. - opt_def = opt_def.sub_optimizers[0] - - if isinstance(opt_def, optimizers.OptaxWrapper): - # T5X wrapped optax optimizers inherit the correct specs, so no state - # updates will be required. - return None - - if not isinstance(opt_def, adafactor.Adafactor): - raise ValueError( - 'Unrecognized optimizer type. Expecting ' - 'optimizers.OptaxWrapper or adafactor.Adafactor. ' - f'Received: {opt_def}' - ) - - if opt_def.hyper_params.factored: - # Factored kernel terms (`v_col` and `v_row`) need to be identified for - # expert sharding. - return training_utils.match_fn(r'.*expert.*/kernel/v_.*') - else: - # Non-factored kernel terms (`v`) inherit the correct specs, so no state - # updates will be required. - return None - - diff --git a/t5x-main/t5x/contrib/moe/partitioning_test.py b/t5x-main/t5x/contrib/moe/partitioning_test.py deleted file mode 100644 index 05c7b73be72991b02bccf86a7af2237399407de0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/partitioning_test.py +++ /dev/null @@ -1,632 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for partitioning.""" - -from typing import Any -import unittest - -from absl.testing import absltest -import flax -from flax import core as flax_core -from flax.linen import partitioning as flax_partitioning -import jax -import numpy as np -import optax -from t5x import adafactor -from t5x import optimizers -from t5x import partitioning as base_partitioning -from t5x import test_utils -from t5x import train_state as train_state_lib -from t5x.contrib.moe import partitioning as moe_partitioning -from t5x.contrib.moe import training_utils - -mock = absltest.mock - -AxisMetadata = flax_partitioning.AxisMetadata -DataLayout = moe_partitioning.DataLayout -FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState -FrozenDict = flax_core.frozen_dict.FrozenDict -FrozenVariableDict = flax_core.scope.FrozenVariableDict -InferenceState = train_state_lib.InferenceState -PartitionSpec = moe_partitioning.PartitionSpec -PRNGKey = Any - - -def create_model_variables( - add_flax_mutables: bool = False, -) -> FrozenVariableDict: - """Creates simple model variables.""" - variables = { - 'params': { - 'logits_dense': np.ones((16, 16), np.float32), - 'mlp': {'wo': {'kernel': np.ones((32, 16), np.float32)}}, - }, - 'params_axes': { - 'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')), - 'mlp': {'wo': {'kernel_axes': AxisMetadata(names=('embed', 'mlp'))}}, - }, - } - if add_flax_mutables: - variables.update({ - 'other': { - 'variables': np.ones((16, 16), np.float32), - }, - 'other_axes': { - 'variables_axes': AxisMetadata(names=('vocab', 'embed')), - }, - }) - - return flax_core.freeze(variables) - - -def create_train_state(add_flax_mutables: bool = False) -> FlaxOptimTrainState: - """Creates simple Adam optimizer train state.""" - optimizer_def = optimizers.adamw(learning_rate=1e-4) - return FlaxOptimTrainState.create( - optimizer_def, create_model_variables(add_flax_mutables) - ) - - -def create_adafactor_train_state(factored: bool = True) -> FlaxOptimTrainState: - """Creates MultiOptimizer train state.""" - optimizer_def = adafactor.Adafactor(learning_rate=0.1, factored=factored) - return FlaxOptimTrainState.create(optimizer_def, create_model_variables()) - - -def create_multioptimizer_train_state( - factored: bool = True, -) -> FlaxOptimTrainState: - """Creates MultiOptimizer train state.""" - - def _is_mlp(path): - return 'mlp' in path - - mlp_vars = flax.traverse_util.ModelParamTraversal( - lambda path, _: not _is_mlp(path) - ) - non_mlp_vars = flax.traverse_util.ModelParamTraversal( - lambda path, _: _is_mlp(path) - ) - scaled_opt = adafactor.Adafactor(learning_rate=0.1, factored=factored) - unscaled_opt = adafactor.Adafactor( - learning_rate=0.1, multiply_by_parameter_scale=False, factored=factored - ) - - optimizer_def = optimizers.MultiOptimizer( - ((mlp_vars, scaled_opt), (non_mlp_vars, unscaled_opt)) - ) - - return FlaxOptimTrainState.create(optimizer_def, create_model_variables()) - - -class PartitioningTest(absltest.TestCase): - - @unittest.skipIf(jax.__version_info__ < (0, 4, 5), 'Test requires jax 0.4.5') - @mock.patch('jax.local_devices') - @mock.patch('jax.devices') - @mock.patch(f'{jax.process_index.__module__}.process_index') - def test_default_mesh(self, process_index_fn, devices_fn, local_devices_fn): - # Mesh with 8 devices. - devices = test_utils.make_devices(2, 2, 1, 2, kind='TPU v3') - devices_fn.return_value = devices - local_devices_fn.return_value = [d for d in devices if d.process_index == 0] - process_index_fn.return_value = 0 - - with self.subTest(name='more_experts_than_devices'): - mesh = moe_partitioning.default_moe_mesh( - num_expert_partitions=16, num_partitions=1 - ) - self.assertEqual(mesh.devices.shape, (1, 8, 1)) - self.assertEqual(mesh.axis_names, ('data', 'expert', 'model')) - - with self.subTest(name='equal_experts_and_devices'): - mesh = moe_partitioning.default_moe_mesh( - num_expert_partitions=8, num_partitions=1 - ) - self.assertEqual(mesh.devices.shape, (1, 8, 1)) - self.assertEqual(mesh.axis_names, ('data', 'expert', 'model')) - - with self.subTest(name='fewer_experts_than_devices'): - mesh = moe_partitioning.default_moe_mesh( - num_expert_partitions=4, num_partitions=1 - ) - self.assertEqual(mesh.devices.shape, (2, 4, 1)) - self.assertEqual(mesh.axis_names, ('data', 'expert', 'model')) - - with self.subTest(name='nontrivial_model_partitions'): - mesh = moe_partitioning.default_moe_mesh( - num_expert_partitions=8, num_partitions=4 - ) - self.assertEqual(mesh.devices.shape, (1, 2, 4)) - self.assertEqual(mesh.axis_names, ('data', 'expert', 'model')) - - with self.subTest(name='specified_model_parallel_submesh'): - mesh = moe_partitioning.default_moe_mesh( - num_expert_partitions=8, model_parallel_submesh=(1, 1, 1, 2) - ) - self.assertEqual(mesh.devices.shape, (1, 4, 2)) - self.assertEqual(mesh.axis_names, ('data', 'expert', 'model')) - - def test_gpu_mesh(self): - mesh = moe_partitioning.get_gpu_mesh() - self.assertEqual(mesh.devices.shape, (1, jax.device_count(), 1)) - self.assertEqual(mesh.axis_names, ('data', 'expert', 'model')) - - def test_cpu_mesh(self): - mesh = moe_partitioning.get_cpu_mesh() - self.assertEqual(mesh.devices.shape, (1, jax.device_count(), 1)) - self.assertEqual(mesh.axis_names, ('data', 'expert', 'model')) - - @unittest.skipIf(jax.__version_info__ < (0, 4, 5), 'Test requires jax 0.4.5') - @mock.patch('jax.local_devices') - @mock.patch('jax.devices') - @mock.patch(f'{jax.process_index.__module__}.process_index') - def test_local_chunker_moe_usage( - self, process_index_fn, devices_fn, local_devices_fn - ): - # The MoE partitioning library uses a 2D "data" mesh spanning ('expert', - # 'data') axes, so we reshape the batch across this 2D "data" mesh when - # computing replica ids from local chunk info. In this test, we check that - # the replica ids constructed in this manner are equivalent to the default - # replica id (over a single 'data' mesh axis). - - # Mesh with 32 devices. - devices = test_utils.make_devices(2, 2, 1, 2, kind='TPU v3') - devices_fn.return_value = devices - local_devices_fn.return_value = [d for d in devices if d.process_index == 0] - process_index_fn.return_value = 0 - - num_expert_partitions = 8 - moe_mesh = moe_partitioning.default_moe_mesh( - num_expert_partitions=num_expert_partitions, num_partitions=2 - ) - moe_chunker = base_partitioning.LocalChunker(moe_mesh) - - base_mesh = base_partitioning.default_mesh(num_partitions=2) - base_chunker = base_partitioning.LocalChunker(base_mesh) - - for batch_size in [8, 16, 32, 64]: - moe_global_array_shape = ( - batch_size // num_expert_partitions, - num_expert_partitions, - ) - moe_replica_id = moe_chunker.get_local_chunk_info( - moe_global_array_shape, ('expert', 'data') - ).replica_id - base_global_array_shape = (batch_size,) - base_replica_id = base_chunker.get_local_chunk_info( - base_global_array_shape, ('data',) - ).replica_id - self.assertEqual(moe_replica_id, base_replica_id) - - @unittest.skipIf(jax.__version_info__ < (0, 4, 5), 'Test requires jax 0.4.5') - @mock.patch('jax.local_devices') - @mock.patch('jax.devices') - @mock.patch(f'{jax.process_index.__module__}.process_index') - def test_local_chunker_data_layout( - self, process_index_fn, devices_fn, local_devices_fn - ): - # Mesh with 32 devices. - devices = test_utils.make_devices(4, 4, 1, 2, kind='TPU v3') - devices_fn.return_value = devices - local_devices_fn.return_value = [d for d in devices if d.process_index == 0] - - for process_index, shard_id in zip([0, 1, 2, 3], [0, 2, 1, 3]): - process_index_fn.return_value = process_index - partitioner = moe_partitioning.MoePjitPartitioner( - num_expert_partitions=8, num_partitions=1 - ) - self.assertEqual( - partitioner.get_data_layout(batch_size=32), - DataLayout( - batch_size=32, - shard_id=shard_id, - num_shards=4, - is_first_host_in_replica_set=True, - ), - ) - - def test_logical_axes_for_moe_partitioner_no_overrides(self): - partitioner = moe_partitioning.MoePjitPartitioner( - num_expert_partitions=8, - num_partitions=1, - state_filter_fn=training_utils.match_fn(r'no_state_matching'), - ) - - train_state = create_train_state() - logical_axes = partitioner.get_logical_axes(train_state) - - # No updates to state. - self.assertEqual( - logical_axes.param_states, - ( - optax.ScaleByAdamState( - count=None, - mu=FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('embed', 'mlp'), - }, - }, - }), - nu=FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('embed', 'mlp'), - }, - }, - }), - ), - optax.EmptyState(), - optax.EmptyState(), - ), - ) - - # Target (params) should be unchanged. - self.assertEqual( - logical_axes.params, - FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('embed', 'mlp'), - }, - }, - }), - ) - - def test_logical_axes_for_moe_partitioner_with_overrides(self): - partitioner = moe_partitioning.MoePjitPartitioner( - num_expert_partitions=8, - num_partitions=1, - state_filter_fn=training_utils.match_fn(r'.*mlp.*'), - ) - - train_state = create_train_state() - logical_axes = partitioner.get_logical_axes(train_state) - - # 'mlp' params should be prepended with 'expert' spec because - # state_filter_fn matches '.*mlp.*'. - self.assertEqual( - logical_axes.param_states, - ( - optax.ScaleByAdamState( - count=None, - mu=FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('expert', 'embed', 'mlp'), - }, - }, - }), - nu=FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('expert', 'embed', 'mlp'), - }, - }, - }), - ), - optax.EmptyState(), - optax.EmptyState(), - ), - ) - - # Target (params) should be unchanged. - self.assertEqual( - logical_axes.params, - FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('embed', 'mlp'), - }, - }, - }), - ) - - def test_logical_axes_for_moe_partitioner_with_flax_mutables(self): - partitioner = moe_partitioning.MoePjitPartitioner( - num_expert_partitions=8, - num_partitions=1, - state_filter_fn=training_utils.match_fn(r'no_state_matching'), - ) - - train_state = create_train_state(add_flax_mutables=True) - logical_axes = partitioner.get_logical_axes(train_state) - - # No updates to state. - self.assertEqual( - logical_axes.param_states, - ( - optax.ScaleByAdamState( - count=None, - mu=FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('embed', 'mlp'), - }, - }, - }), - nu=FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('embed', 'mlp'), - }, - }, - }), - ), - optax.EmptyState(), - optax.EmptyState(), - ), - ) - - # Target (params) should be unchanged. - self.assertEqual( - logical_axes.params, - FrozenDict({ - 'logits_dense': PartitionSpec('vocab', 'embed'), - 'mlp': { - 'wo': { - 'kernel': PartitionSpec('embed', 'mlp'), - }, - }, - }), - ) - - # Should preserve flax_mutables. - self.assertEqual( - logical_axes.flax_mutables, - FrozenDict({ - 'other': { - 'variables': PartitionSpec('vocab', 'embed'), - }, - }), - ) - - def test_inference_state_logical_axes(self): - partitioner = moe_partitioning.MoePjitPartitioner( - num_expert_partitions=8, num_partitions=1 - ) - - model_variables = flax_core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - }) - train_state = InferenceState.create(model_variables) - logical_axes = partitioner.get_logical_axes(train_state) - - # No expert axis overrides to InferenceState. Partition specs should match - # input axis metadata. - self.assertEqual( - logical_axes, - InferenceState( - step=None, - params=flax_core.FrozenDict({ - 'dense': { - 'bias': PartitionSpec( - 'embed', - ), - 'kernel': PartitionSpec('vocab', 'embed'), - }, - }), - ), - ) - - def test_infer_state_function(self): - with self.subTest(name='optax'): - optax_train_state = create_train_state() - self.assertIsNone( - moe_partitioning._infer_state_filter_fn(optax_train_state) - ) - - with self.subTest(name='factored_adafactor'): - adafactor_train_state = create_adafactor_train_state(factored=True) - match_fn = moe_partitioning._infer_state_filter_fn(adafactor_train_state) - self.assertTrue(match_fn('expert/kernel/v_col')) - self.assertTrue(match_fn('expert/kernel/v_row')) - self.assertFalse(match_fn('expert/kernel/m')) - self.assertFalse(match_fn('kernel/v_col')) - - with self.subTest(name='unfactored_adafactor'): - adafactor_train_state = create_adafactor_train_state(factored=False) - self.assertIsNone( - moe_partitioning._infer_state_filter_fn(adafactor_train_state) - ) - - with self.subTest(name='factored_adafactor_multi_optimizer'): - multi_opt_train_state = create_multioptimizer_train_state(factored=True) - match_fn = moe_partitioning._infer_state_filter_fn(multi_opt_train_state) - self.assertTrue(match_fn('expert/kernel/v_col')) - self.assertTrue(match_fn('expert/kernel/v_row')) - self.assertFalse(match_fn('expert/kernel/m')) - self.assertFalse(match_fn('kernel/v_col')) - - with self.subTest(name='unfactored_adafactor_multi_optimizer'): - multi_opt_train_state = create_multioptimizer_train_state(factored=False) - self.assertIsNone( - moe_partitioning._infer_state_filter_fn(multi_opt_train_state) - ) - - with self.subTest(name='mixed_factoring_adafactor_multi_optimizer'): - true_vars = flax.traverse_util.ModelParamTraversal(lambda p, _: True) - false_vars = flax.traverse_util.ModelParamTraversal(lambda p, _: False) - factored_opt = adafactor.Adafactor(learning_rate=0.1, factored=True) - unfactored_opt = adafactor.Adafactor(learning_rate=1.0, factored=False) - optimizer_def = optimizers.MultiOptimizer( - ((true_vars, factored_opt), (false_vars, unfactored_opt)) - ) - multi_opt_train_state = FlaxOptimTrainState.create( - optimizer_def, create_model_variables() - ) - - with self.assertRaisesRegex( - ValueError, 'all suboptimizers must be either factored or unfactored' - ): - _ = moe_partitioning._infer_state_filter_fn(multi_opt_train_state) - - def test_logical_axis_rules(self): - self.assertEqual( - moe_partitioning.standard_logical_axis_rules( - additional_rules=[('additional', 'model'), ('expert_magic', 'data')] - ), - [ - ('batch', ('expert', 'data')), # Shard batch over entire mesh - # No sharding of weights over model axis. - ('vocab', 'model'), - ('mlp', 'model'), - ('heads', 'model'), - ('kv', None), - ('joined_kv', 'model'), - ('embed', 'model'), # Default is 2-D sharding - ('relpos_buckets', None), - ('abspos_buckets', None), - ('length', None), - ('layers', None), - ('stack', None), - ('mlp_activations', None), - ('expert', 'expert'), # Shard experts along expert axis - ('expert_mlp', 'model'), - # Experts replicated along "pure" data axis - ('expert_replicas', 'data'), - ('unmodeled', None), - ('additional', 'model'), - ('expert_magic', 'data'), - ], - ) - - def test_data_partition_spec(self): - partitioner = moe_partitioning.MoePjitPartitioner( - num_expert_partitions=2, num_partitions=1 - ) - self.assertEqual( - partitioner.data_partition_spec, - PartitionSpec( - ('expert', 'data'), - ), - ) - - def test_axis_resource_overrides(self): - with self.subTest(name='sequence_of_resources'): - input_resources = ( - PartitionSpec('data'), - PartitionSpec('model'), - PartitionSpec('expert'), - None, - PartitionSpec('unrecognized'), - ) - # 'data' -> ('expert', 'data'). - self.assertEqual( - moe_partitioning.override_partition_specs(input_resources), - ( - PartitionSpec( - ('expert', 'data'), - ), - PartitionSpec('model'), - PartitionSpec('expert'), - None, - PartitionSpec( - 'unrecognized', - ), - ), - ) - - with self.subTest(name='single_resource'): - # 'data' -> ('expert', 'data'). - self.assertEqual( - moe_partitioning.override_partition_specs( - PartitionSpec( - 'data', - ) - ), - PartitionSpec( - ('expert', 'data'), - ), - ) - - with self.subTest(name='no_override'): - # 'data' -> ('expert', 'data'). - self.assertEqual( - moe_partitioning.override_partition_specs( - PartitionSpec(('expert', 'data')) - ), - PartitionSpec( - ('expert', 'data'), - ), - ) - - with self.subTest(name='no_resource'): - self.assertIsNone(moe_partitioning.override_partition_specs(None)) - - def test_compute_num_model_partitions(self): - with self.subTest(name='no_model_parallel_submesh'): - self.assertEqual( - moe_partitioning.compute_num_model_partitions( - num_model_partitions=2, model_parallel_submesh=None - ), - 2, - ) - - with self.subTest(name='no_model_partitions'): - self.assertEqual( - moe_partitioning.compute_num_model_partitions( - num_model_partitions=None, model_parallel_submesh=(1, 2, 1, 2) - ), - 4, - ) - - with self.subTest(name='partitions_and_submesh'): - self.assertEqual( - moe_partitioning.compute_num_model_partitions( - num_model_partitions=4, model_parallel_submesh=(1, 2, 1, 2) - ), - 4, - ) - - with self.subTest(name='inconsistent_partitions'): - with self.assertRaisesRegex( - ValueError, - 'num_model_partitions and model_parallel_submesh are inconsistent.', - ): - _ = moe_partitioning.compute_num_model_partitions( - num_model_partitions=1, model_parallel_submesh=(1, 2, 1, 2) - ) - - with self.subTest(name='no_submesh_or_partitions'): - with self.assertRaisesRegex( - ValueError, - ( - 'At least one of num_model_partitions and model_parallel_submesh' - ' must be specified.' - ), - ): - _ = moe_partitioning.compute_num_model_partitions( - num_model_partitions=None, model_parallel_submesh=None - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/contrib/moe/trainer.py b/t5x-main/t5x/contrib/moe/trainer.py deleted file mode 100644 index 1a1a02241b67eb4ba5c2bcb0e0bf24b90ca52dd3..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/trainer.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Trainer with Mixture of Experts support.""" - -from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING - -import cached_property -from t5x import models -from t5x import train_state as train_state_lib -from t5x import trainer -from t5x.contrib.moe import partitioning -from t5x.contrib.moe import training_utils - -BatchType = trainer.BatchType -LearningRateCallable = trainer.LearningRateCallable -MetricMapType = trainer.MetricMapType -PartitionSpec = partitioning.PartitionSpec -PartitionedTrainCallable = trainer.PartitionedTrainCallable -Rng = trainer.Rng - -if TYPE_CHECKING: # See b/163639353 - cached_property = property # pylint: disable=invalid-name -else: - cached_property = cached_property.cached_property - - -class MoeTrainer(trainer.Trainer): - """T5X trainer with overrides for Mixture of Experts support.""" - - def __init__( - self, - model: models.BaseModel, - train_state: train_state_lib.TrainState, - partitioner: partitioning.MoePjitPartitioner, - eval_names: Sequence[str], - summary_dir: Optional[str], - train_state_axes: Any, - rng: Rng, - learning_rate_fn: LearningRateCallable, - num_microbatches: Optional[int], - num_expert_partitions: int, - sharded_match_fn: Optional[ - Callable[[str], bool] - ] = training_utils.match_fn(r'.*expert.*'), - weight_metrics_computer: Optional[trainer.WeightMetricsComputer] = None, - ): - """Trainer constructor. - - Args: - model: the instantiation of `BaseModel` to train. - train_state: a train state with parameters and optimizer state. - partitioner: the partitioner to use. - eval_names: names of evaluation datasets, which must match the keys of the - mapping passed to `eval`. - summary_dir: optional directory to write TensorBoard metrics to. - train_state_axes: partitioning info for the optimizer to be used. - rng: jax PRNGKey seed for random operations, to be combined with step - number for a deterministic RNG. - learning_rate_fn: returns the learning rate given the current step. - num_microbatches: the number of microbatches to use, or None for direct - training. - num_expert_partitions: Size of expert parallel submesh. Used to scale - sharded parameter gradients. - sharded_match_fn: Filter function for distinguishing sharded (MoE) - parameters from replicated parameters. Used to identify the sharded - parameter gradients that need to be rescaled under pjit training. - weight_metrics_computer: A WeightMetricsComputer instance, or None, to - decide what metrics, if any, to log about weights and weight updates - during training. - """ - super().__init__( - model=model, - train_state=train_state, - partitioner=partitioner, - eval_names=eval_names, - summary_dir=summary_dir, - train_state_axes=train_state_axes, - rng=rng, - learning_rate_fn=learning_rate_fn, - num_microbatches=num_microbatches, - weight_metrics_computer=weight_metrics_computer, - ) - - self._num_expert_partitions = num_expert_partitions - self._sharded_match_fn = sharded_match_fn - self.data_partition_spec = partitioner.data_partition_spec - - @cached_property - def _partitioned_train_step(self) -> PartitionedTrainCallable: - """Same as a regular T5X train step, but scales expert parameter gradients. - - We must scale expert parameter gradients by the number of experts to account - for pjit's implicit averaging over partitioned parameter gradients. - - Returns: - Partitioned train step function. - """ - - def train_with_lr( - train_state: train_state_lib.TrainState, batch: BatchType - ): - grad_accum, metrics, flax_mutables = ( - trainer.accumulate_grads_microbatched( - self._model, - train_state, - batch, - self._get_step_rng(train_state.step), # pytype: disable=wrong-arg-types # jax-ndarray - self._num_microbatches, - data_partition_spec=self.data_partition_spec, - ) - ) - - # Only difference between this train step and regular T5X train step: - scaled_grads = training_utils.scale_sharded_grads( - grad_accum, - self._sharded_match_fn, - scale_factor=self._num_expert_partitions, - ) - - new_train_state, metrics = trainer.apply_grads( - train_state, - scaled_grads, - metrics, - self._learning_rate_fn(train_state.step), - self._weight_metrics_computer, - other_state_variables={'flax_mutables': flax_mutables} - if flax_mutables - else None, - ) - return new_train_state, metrics - - return self._partitioner.partition( - train_with_lr, - in_axis_resources=(self._train_state_axes, self.data_partition_spec), - out_axis_resources=(self._train_state_axes, None), - donate_argnums=(0,), - ) diff --git a/t5x-main/t5x/contrib/moe/trainer_test.py b/t5x-main/t5x/contrib/moe/trainer_test.py deleted file mode 100644 index 5536ecfd43ba3ccfc8a65ee171463a9c63fe5fb4..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/trainer_test.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trainer.""" - -from absl.testing import absltest -import jax -import numpy as np -from t5x import metrics as metrics_lib -from t5x import models as models_lib -from t5x import optimizers -from t5x import train_state as train_state_lib -from t5x.contrib.moe import partitioning -from t5x.contrib.moe import trainer as trainer_lib -import tensorflow as tf - -mock = absltest.mock -jax.config.parse_flags_with_absl() - - -def fake_accum_grads( - model, optimizer, batch, rng, num_microbatches, data_partition_spec -): - del model, num_microbatches, rng, data_partition_spec - # Add `i` to each optimzer value. - i = batch['i'].sum() - grad_accum = jax.tree.map(lambda x: i, optimizer) - # Add j to each metric. - j = batch['j'].sum() - metrics = { - 'loss': metrics_lib.Sum.from_model_output(j), - 'accuracy': metrics_lib.Sum.from_model_output(j), - } - return grad_accum, metrics, None - - -def fake_apply_grads( - optimizer, - grad_accum, - metrics, - learning_rate, - weight_metrics_computer, - other_state_variables=None, -): - del weight_metrics_computer - del other_state_variables - metrics['learning_rate'] = metrics_lib.Sum.from_model_output(learning_rate) - optimizer = jax.tree.map(lambda x, g: x + g, optimizer, grad_accum) - return optimizer, metrics - - -class MoeTrainerTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.init_optimizer = optimizers.Optimizer( - optimizers.sgd(0.1), - state=optimizers.OptimizerState( - step=0, param_states={'expert_bias': 0, 'kernel': 0} - ), - target={'expert_bias': np.zeros(4), 'kernel': np.zeros((2, 4))}, - ) - self.init_train_state = train_state_lib.FlaxOptimTrainState( - self.init_optimizer - ) - train_state_axes = jax.tree.map(lambda x: None, self.init_train_state) - model_dir = self.create_tempdir().full_path - - mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} - self.dataset = ( - tf.data.Dataset.range(6).map(mapfn).batch(2, drop_remainder=True) - ) - - num_expert_partitions = 10 - self.test_trainer = trainer_lib.MoeTrainer( - model=mock.create_autospec(models_lib.BaseModel, instance=True), - train_state=self.init_train_state, - partitioner=partitioning.MoePjitPartitioner( - num_expert_partitions=num_expert_partitions, num_partitions=1 - ), - eval_names=['task1', 'task2'], - summary_dir=model_dir, - train_state_axes=train_state_axes, - rng=np.ones(2, np.uint32), - learning_rate_fn=lambda step: 2 * step, - num_microbatches=None, - num_expert_partitions=num_expert_partitions, - ) - - @mock.patch('t5x.trainer._time') - @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_accum_grads) - @mock.patch('t5x.trainer.apply_grads', fake_apply_grads) - def _test_train(self, precompile, mock_time=None): - trainer = self.test_trainer - initial_rng = trainer._base_rng - - if precompile: - mock_time.side_effect = [0, 1] - trainer.compile_train(next(self.dataset.as_numpy_iterator())) - trainer._compiled_train_step = mock.Mock( - side_effect=trainer._compiled_train_step - ) - - trainer._partitioned_train_step = mock.Mock( - side_effect=trainer._partitioned_train_step - ) - - # train start, logging, train end, logging - mock_time.side_effect = [1, 5] - num_steps = 2 - trainer.train(self.dataset.as_numpy_iterator(), num_steps) - - # Base rng must remain the same. - np.testing.assert_array_equal(trainer._base_rng, initial_rng) - - expected_optimizer = optimizers.Optimizer( - self.init_optimizer.optimizer_def, - state=optimizers.OptimizerState( - step=[6], - param_states={ - 'expert_bias': 60, # 10 * (0+1+2+3) = 60 - 'kernel': 6, # 0+1+2+3 = 6 - }, - ), - target={'expert_bias': 60 * np.ones(4), 'kernel': 6 * np.ones((2, 4))}, - ) - expected_train_state = train_state_lib.FlaxOptimTrainState( - expected_optimizer - ) - jax.tree.map( - np.testing.assert_allclose, trainer.train_state, expected_train_state - ) - - if precompile: - self.assertEqual(trainer._compiled_train_step.call_count, num_steps) - trainer._partitioned_train_step.assert_not_called() - else: - self.assertIsNone(trainer._compiled_train_step) - self.assertEqual(trainer._partitioned_train_step.call_count, num_steps) - - def test_train_noprecompile(self): - self._test_train(False) - - def test_train_precompile(self): - self._test_train(True) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/contrib/moe/training_utils.py b/t5x-main/t5x/contrib/moe/training_utils.py deleted file mode 100644 index d5044b70eb1fff4f458e7c5dd4879720bc42a474..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/training_utils.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Extensions to Jax/Flax core functions for Mixture of Experts training. - -""" - -import dataclasses -import re -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union - -import flax -import jax -import numpy as np -from t5x import train_state - -# Type Stubs -ParamTree = Any -Gradients = Union[flax.core.FrozenDict, train_state.TrainState] - - -def match_fn(prefix: Optional[str]) -> Callable[[str], bool]: - """Creates a function returning true iff a string matches the prefix. - - Args: - prefix: Regex prefix to match. If none, then return match function will not - match any strings. - - Returns: - Prefix match function. - """ - if not prefix: - return lambda name: False - params_regex = re.compile(f'^{prefix}') - return lambda name: params_regex.match(name) is not None - - -def scale_sharded_grads( - grads: Gradients, - sharded_match_fn: Optional[Callable[[str], bool]], - scale_factor: float, -) -> Gradients: - """Scales sharded grads, identified by sharded_match_fn, by scale_factor. - - Args: - grads: Parameter gradients. - sharded_match_fn: Filter function for distinguishing sharded parameters from - replicated parameters. - scale_factor: Amount by which to scale sharded parameter gradients. - - Returns: - Gradients matching input, expect with sharded parameter gradients rescaled. - """ - if sharded_match_fn: - names_and_grads, tree_def = _tree_flatten_with_names(grads) - scaled_grads = [ - grad * scale_factor if sharded_match_fn(name) else grad - for name, grad in names_and_grads - ] - return tree_def.unflatten(scaled_grads) - else: - return grads - - -def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True): - """Like jax.tree.map but with a filter on the leaf path name. - - Args: - f: The function to be applied to each parameter in `param_tree`. - param_tree: The tree of parameters `f` should be applied to. - match_name_fn: This function is called with each tree leave's path name, - which has a path-like format ('a/b/c'), and decides whether `f` should be - applied to that leaf or the leaf should be kept as-is. - - Returns: - A tree identical in structure to `param_tree` but with the leaves the - result of calling `f` on them in the cases where `match_name_fn` returns - True for that leaf's path name. - """ - names_and_vals, tree_def = _tree_flatten_with_names(param_tree) - vals = [f(v) if match_name_fn(name) else v for name, v in names_and_vals] - return tree_def.unflatten(vals) - - -def _tree_flatten_with_names( - tree: ParamTree, -) -> Tuple[Sequence[Tuple[str, Any]], jax.tree_util.PyTreeDef]: - """Like jax.tree_util.tree_flatten but also fetches leaf names. - - Specialized to parameter trees of the form {'key0': {'subkey0': Any}, ...}. - - Args: - tree: Tree of parameters to flatten. - - Returns: - - A list of leaf name and value pairs: [(name, value), ...]. - - A tree definition object representing the structure of the flattened tree. - """ - # PyTrees don't treat None values as leaves, so we explicitly declare them as - # such. - vals, tree_def = jax.tree_util.tree_flatten(tree, is_leaf=lambda x: x is None) - - # 'Fake' token tree that is use to track jax internal tree traversal and - # adjust our custom tree traversal to be compatible with it. - tokens = range(len(vals)) - token_tree = tree_def.unflatten(tokens) - val_names, perm = zip(*_traverse_with_names(token_tree)) - inv_perm = np.argsort(perm) - - # Custom traversal should visit the same number of leaves. - if len(val_names) != len(vals): - raise ValueError( - f'Pytree traversal detected {len(val_names)} names, ' - f'but {len(vals)} leafs.\nTreeDef is:\n{tree_def}' - ) - - return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def - - -def _traverse_with_names( - param_tree: ParamTree, -) -> Iterable[Tuple[str, ParamTree]]: - """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" - if dataclasses.is_dataclass(param_tree): - param_tree = flax.serialization.to_state_dict(param_tree) - if isinstance(param_tree, (dict, flax.core.FrozenDict)): - keys = sorted(param_tree.keys()) - for key in keys: - for path, v in _traverse_with_names(param_tree[key]): - yield (key + '/' + path).rstrip('/'), v - else: - yield '', param_tree diff --git a/t5x-main/t5x/contrib/moe/training_utils_test.py b/t5x-main/t5x/contrib/moe/training_utils_test.py deleted file mode 100644 index 8cb47d48b072d43a3a4763caeabf5e96b2f33b1b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/contrib/moe/training_utils_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for training_utils.""" - -import functools -import os -# Emulate 2 devices on CPU. Import before JAX. -os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2' - -from absl.testing import absltest # pylint: disable=g-import-not-at-top -from flax import core as flax_core -import jax -from jax import numpy as jnp -import numpy as np - -from t5x.contrib.moe import training_utils - - -class MatchFnTest(absltest.TestCase): - - def test_regex_prefix(self): - match_fn = training_utils.match_fn(r'.*test.*') - self.assertTrue(match_fn('/test/something')) - self.assertTrue(match_fn('to/test/or/not/')) - self.assertFalse(match_fn('no/match')) - - def test_empty_prefix(self): - match_fn = training_utils.match_fn(None) - self.assertFalse(match_fn('/test/something')) - self.assertFalse(match_fn('to/test/or/not/')) - - -class ScaleShardedGradsTest(absltest.TestCase): - - def test_scale_sharded_grads(self): - grads = flax_core.freeze({ - 'encoder': { - 'expert_layer': jnp.ones((2, 3)), - 'regular_layer': jnp.ones((1, 2)), - } - }) - sharded_match_fn = training_utils.match_fn(r'.*expert.*') - scaled_grads = training_utils.scale_sharded_grads( - grads, sharded_match_fn, scale_factor=100.0 - ) - - expected_grads = flax_core.freeze({ - 'encoder': { - 'expert_layer': 100.0 * jnp.ones((2, 3)), - 'regular_layer': jnp.ones((1, 2)), - } - }) - jax.tree.map( - functools.partial(np.testing.assert_allclose, rtol=3e-7), - scaled_grads, - expected_grads, - ) - - -class TreeTest(absltest.TestCase): - - def test_tree_flatten_with_names(self): - tree = {'ff_0': {'kernel': 0, 'bias': 1}, 'ff_1': {'kernel': 2, 'bias': 3}} - names_and_values, _ = training_utils._tree_flatten_with_names(tree) - - expected_names_and_values = [ - ('ff_0/bias', 1), - ('ff_0/kernel', 0), - ('ff_1/bias', 3), - ('ff_1/kernel', 2), - ] - self.assertEqual(names_and_values, expected_names_and_values) - - # Check that values match regular JAX tree_flatten. - self.assertEqual( - [x for _, x in names_and_values], jax.tree_util.tree_flatten(tree)[0] - ) - - def test_tree_map_with_names(self): - tree = {'a': 1, 'b': 2} - mapped_tree = training_utils.tree_map_with_names( - f=lambda x: -x, param_tree=tree, match_name_fn=lambda name: name == 'b' - ) - - self.assertEqual(mapped_tree, {'a': 1, 'b': -2}) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/decoding.py b/t5x-main/t5x/decoding.py deleted file mode 100644 index fe58af5df75774f36dd9db41497937fa0b2e8493..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/decoding.py +++ /dev/null @@ -1,1569 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fast decoding routines for inference from a trained model.""" - -import functools -from typing import Any, Callable, Mapping, Optional, Tuple, Union -import flax -from flax import traverse_util -import jax -from jax import lax -from jax import random -import jax.numpy as jnp -import numpy as np -from t5x import binary_search - -PyTree = Any -PyTreeDef = jax.tree_util.PyTreeDef - -# Constants -# "Effective negative infinity" constant for masking in beam search. -NEG_INF = np.array(-1.0e7) -NEG_INF_VALUE = -1.0e7 - -# Temperatures lower than this are considered 0.0, which is handled specially -# with a conditional. This is to avoid numeric issues from exponentiating on -# 1.0/temperature when temperature is close to 0.0. -MIN_TEMPERATURE = np.array(1e-4) - - -@flax.struct.dataclass -class DecodingState: - """Holds decoding state data. - - Used to communicate the current decoding state to tokens_to_logits methods. - Note that we use a different class than `SamplingLoopState` or `Beamstate` to - decouple the concerns of what data is useful for the loop vs. what the - sampling method needs. - Decodes for a given batch entry are flattened in a column-major way so that - decodes from the same batch entry are grouped together. - - Attributes: - cur_index: [batch_size * num_decodes] array position of the sampling loop in - the length dimension. - sequences: [batch_size * num_decodes, max_decode_len] array of current - sampled sequence prefixes. - cur_token: [batch_size * num_decodes] single timestep slice containing - current tokens. - cache: any mapping of arrays, e.g. flax attention cache. - """ - - cur_index: jnp.ndarray - sequences: jnp.ndarray - cur_token: jnp.ndarray - cache: Mapping[str, jnp.ndarray] - - -# ------------------------------------------------------------------------------ -# Temperature Sampling -# ------------------------------------------------------------------------------ - - -@flax.struct.dataclass -class SamplingLoopState: - """Holds sampling state data. - - Attributes: - step: Scalar decoding step count. Starts from zero. - cur_index: [batch_size * num_decodes] array position of the sampling loop in - the length dimension. - sequences: [batch_size * num_decodes, max_decode_len] array of current - sampled sequence prefixes. - cache: any mapping of arrays, e.g. flax attention cache. - cur_token: [batch_size * num_decodes] single timestep slice containing - current tokens. - ended: [batch_size * num_decodes] binary array marking completed sequences. - rng: Jax PRNGKey - log_prob: [batch_size * num_decodes] array of log probs for each sequence. - """ - - step: jnp.ndarray - cur_index: jnp.ndarray - sequences: jnp.ndarray - cache: Mapping[str, jnp.ndarray] - cur_token: jnp.ndarray - ended: jnp.ndarray - rng: jnp.ndarray - log_prob: jnp.ndarray - - -_dynamic_update_vector_slice_in_dim = jax.vmap( - lax.dynamic_update_slice_in_dim, in_axes=(0, 0, 0, None) -) - - -def _is_tracer(value: Any): - return isinstance(value, jax.core.Tracer) - - -StateCallbackFn = Callable[[SamplingLoopState], SamplingLoopState] -LogitCallbackFn = Callable[[jnp.ndarray, SamplingLoopState], jnp.ndarray] - - -def temperature_sample( - inputs: jnp.ndarray, - cache: Mapping[str, jnp.ndarray], - tokens_to_logits: Callable[ - [DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]] - ], - eos_id: int, - decode_rng: Optional[jnp.ndarray] = None, - num_decodes: int = 1, - temperature: Union[float, jnp.ndarray] = 1.0, - topk: int = 1, - topp: float = 0.0, - cache_offset: int = 0, - initial_index: Optional[jnp.ndarray] = None, - max_decode_steps: Optional[Union[int, jnp.ndarray]] = None, - max_decode_steps_hard_limit: Optional[int] = None, - rescale_log_probs: bool = True, - state_callback_fn: Optional[StateCallbackFn] = None, - logit_callback_fn: Optional[LogitCallbackFn] = None, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Temperature sampling for language model generation. - - The temperature sampling is performed `num_decodes` times in a vectorized - manner by expanding the batch dimension. This is similar to how beam search - expands the batch dimension to process each batch element with multiple beams. - - This function dynamically updates the `inputs` array by sampling from the - model logits, which is provided by `tokens_to_logits` callable. The input - sequences are expanded at the end, populated and sliced by dropping the first - position. - - If `inputs` has non-zero entries, those values are not modified, i.e., - the sampled values for those positions are discarded. This simulates the - teacher forcing on the prefix positions. - - There are a few important observations related to this function. - - 1. The `inputs` is assumed to be a non-packed sequence. - - 2. If `initial_index=None`, then `inputs`[:, 0] is ignored. We will use 0 as a - BOS token to start the generation. This inherently assumes that `inputs` is - already shifted to the right by one position. If `initial_index=an_array`, - the token values at `inputs`[:, initial_index] are used as the token to - start the generation. - - 3. The loop index, i, is a vector of shape [batch_size]. When beginning - generation from scratch, each value will always have the same value. When - beginning with a partially filled cache, the loop index of different - elements can differ, via providing a value for `initial_index`. - - 3. Unless all batch elements generated the eos_id before reaching the end, we - always make `max_decode_len = inputs.shape[1]` number of calls to - `tokens_to_logits` when decoding from scratch and - `max_decode_len - jnp.minimum(initial_index)` number of calls when starting - from a partially filled cache. - - 4. Let `output` be the output sequences, i.e.,`sequences`[:, 1:]. Then - `output`[:, j] are the tokens generated when the while loop counter `i = - j`. Therefore, we generate the last token when `i = max_decode_len - 1` - and exit the while loop as all `i`s are incremented to `max_decode_len`. - - 5. Once `eos_id = 1` is generated, the subsequent predictions are all replaced - by padding token 0. - - 6. When using a partially filled cache, different batch elements can have - different lengths. This means an input that has a longer input will have - fewer steps until its `i` value reaches `max_decode_len` than an input with - a shorter input. We keep these longer examples alive, doing busy work - continually overwriting a new garbage token at the end of the sequence - until shorter examples finish. - - 7. When using a partially filled cache, providing a value for `initial_index`, - the attention cache index should be a vector of [batch_size]. - - We show three examples to illustrate how this function works. In addition to - input and output of the function, we also show two intermediate values: - `expanded_prompt_inputs` and `final_sequences`. Also for simplicity, the - examples are limited to `num_decodes = 1` usage and the `num_decodes` - dimension is omitted. - - ``` - Example 1: - inputs = [0, 5, 6, 1, 0] - expanded_prompt_inputs = [0, 5, 6, 1, 0, 0] - final_sequences = [0, 5, 6, 1, a, b] # before slicing. - output = [5, 6, 1, a, b] - where `a` is prediction while taking 1 as input and `b` is prediction while - taking `a` as input. - - Example 2 (early stopping): - inputs = [[0, 5, 1, 0, 0, 0, 0], - [0, 8, 0, 0, 0, 0, 0] - expanded_prompt_inputs = [[0, 5, 1, 0, 0, 0, 0, 0], - [0, 8, 0, 0, 0, 0, 0, 0] - final_sequences = [[0, 5, 1, a, b, c=1, 0, 0], - [0, 8, d, e, f=1, g=0, 0, 0]] - output = [[5, 1, a, b, c=1, 0, 0], - [8, d, e, f=1, g=0, 0, 0]] - - In this example, there are two sequences. Let's look at sequence 0. The - first generated token is `a`, which is in turn used to generate `b`. - Finally, `c = 1` is generated with the input `b`. Then the loop terminates - early because 1 is the `eos_id`. - - Now consider sequence 1. The when `f = 1` was generated, it is considered - done. Since sequence 0 is not done at this point, the next prediction, i.e., - `g` is zerod out. This continues until the end. - - Example 3 (prefilled cache): - inputs = [[0, 5, 2, 6, 1, 0], - [0, 8, 1, 0, 0, 0]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, 0, 0, 0], - [0, 8, 1, 0, 0, 0, 0, 0]] - max_decode_length = 6 - i = [4, 2] - input_tokens = [[1], - [1]] - output_tokens = [[a], - [b]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, 0, 0], - [0, 8, 1, b, 0, 0, 0, 0]] - i = [5, 3] - input_tokens = [[a], - [b]] - output_tokens = [[c], - [d]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, 0], - [0, 8, 1, b, d, 0, 0, 0]] - i = [6, 4] - input_tokens = [[c], - [d]] - output_tokens = [[y], - [e]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, y], - [0, 8, 1, b, d, e, 0, 0]] - i = [6, 5] - input_tokens = [[z], - [e]] - output_tokens = [[z], - [f]] - expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, z], - [0, 8, 1, b, d, e, f, 0]] - i = [6, 6] - exit - outputs = [[5, 2, 6, 1, a, c], - [8, 1, b, d, e, f]] - - In this example, there are two sequences with different input lengths. Thus - the two caches had been filled to different positions. As we decode, the - first sequence hits the max decode length before the second. In order to - avoid prematurely ending decoding for the second sequence, the first - sequence continually overwrites the final token. - - Example 4 (prefilled cache and max decode steps): - inputs = [[0, 2, 0, 0, 0, 0, 0, 0], - [0, 3, 4, 0, 0, 0, 0, 0]] - expanded_prompt_inputs = [[0, 2, 0, 0, 0, 0, 0, 0, 0, 0] - [0, 3, 4, 0, 0, 0, 0, 0, 0, 0]] - initial_indices = [1, 2] - max_decode_step = 2 - - Then `max_decode_len = [3, 4]`. - i = [1, 2] - input_tokens = [[2], - [4]] - output_tokens = [[a], - [b]] - expanded_prompt_inputs = [[0, 2, a, 0, 0, 0, 0, 0, 0, 0] - [0, 3, 4, b, 0, 0, 0, 0, 0, 0]] - i = [2, 3]] - input_tokens = [[a], - [b]] - output_tokens = [[c], - [d]] - expanded_prompt_inputs = [[0, 2, a, c, 0, 0, 0, 0, 0, 0] - [0, 3, 4, b, d, 0, 0, 0, 0, 0]] - This is the last while loop iteration with i == max_decode_len - 1. - outputs = [[2, a, c, 0, 0, 0, 0, 0] - [3, 4, b, d, 0, 0, 0, 0]] - ``` - - Args: - inputs: array: [batch_size, max_decode_len] int32 sequence of tokens. - cache: flax attention cache. - tokens_to_logits: fast autoregressive decoder function taking single token - slices and cache and returning next-token logits and updated cache. - eos_id: int: end-of-sentence token for target vocabulary. - decode_rng: JAX PRNGKey. - num_decodes: number of decoded sequences to be returned. - temperature: float: sampling temperature factor. As it approaches zero this - becomes equivalent to greedy sampling. You may also provide an array of - floats of size batch_size to use different temperature values for each - batch item. - topk: integer: if nonzero only use the top-k logits to sample next token, if - zero don't use any cutoff and sample from full logits over vocabulary. - topp: float: if nonzero only use the smallest number of logits whose - cumulative sum of probs adds up to (at least) topp. Will raise ValueError - if it's nonzero when topk is nonzero. - cache_offset: axis offset for cache, arising from scanned layers. - initial_index: Optional[array]: [batch_size] int32 a vector of loop indexes - to start decoding at. - max_decode_steps: int: an optional maximum number of decoding steps. If - None, it will decode until the full input shape `inputs.shape[1]` is - filled. max_decode_steps begins counting after the prompt, so it will - decode at most len(prompt) + max_decode_steps tokens. - max_decode_steps_hard_limit: int: an optional fixed hard limit on - max_decode_steps. If this is set (not None and > 0), and max_decode_steps - is also set, then max_decode_steps will be clipped to this limit. The - value max_decode_steps can be an ndarray, but max_decode_steps_hard_limit - must be a Python integer or None. - rescale_log_probs: bool: whether to apply temperature, topp, and topk - rescaling to the log probs which are returned. If True, the log_probs will - include these transformations (for example, with topk=1, all log_probs - will be identically 0.0). If False, the log_probs will not be affected, - and topk/topp/temperature will not affect sequence probabilities. - state_callback_fn: Function that modifies the sampling loop state before - each step. This can be used to manipulate any part of the state either on - the accelerator or on the host using host callback. The function should - take a SamplingLoopState as argument, and it returns the updated state. - See `decoding_test.py` for an example usage. - logit_callback_fn: Function that modifies the logits before each temperature - sampling step. The function should take arguments (logits, state) and it - should return the modified logits. See `decoding_test.py` for an example - usage. - - Returns: - A tuple (decodes, log_prob) where `decodes` is sampled sequences with shape - [batch_size, num_decodes, max_decode_len] sorted by `log_prob`, which is log - probability of each of the sampled sequences. - """ - if decode_rng is None: - decode_rng = jax.random.PRNGKey(0) - - if ( - max_decode_steps_hard_limit is not None - and max_decode_steps_hard_limit > 0 - and max_decode_steps is not None - ): - max_decode_steps = jnp.minimum( - max_decode_steps, max_decode_steps_hard_limit - ) - - if num_decodes > 1: - # [batch, len] -> [batch * num_decodes, len] - expanded_inputs = flat_batch_beam_expand(inputs, num_decodes) - expanded_cache = cache_map( - functools.partial( - flat_batch_beam_expand, beam_size=num_decodes, offset=cache_offset - ), - cache, - # When we start with a prefilled cache, the cache index is no longer a - # scalar that will broadcast across multiple decodes, it is a vector and - # needs to be updated to handle the multiple decodes. - apply_to_index=initial_index is not None, - ) - if initial_index is not None: - initial_index = flat_batch_beam_expand(initial_index, num_decodes) - else: - expanded_inputs = inputs - expanded_cache = cache - - # expanded_decodes: [batch * num_decodes, len] - # expanded_log_prob: [batch * num_decodes] - expanded_decodes, expanded_log_prob = _temperature_sample_single_trial( - expanded_inputs, - expanded_cache, - tokens_to_logits, - eos_id, - decode_rng, - num_decodes, - temperature, - topk, - topp, - initial_index=initial_index, - max_decode_steps=max_decode_steps, - rescale_log_probs=rescale_log_probs, - state_callback_fn=state_callback_fn, - logit_callback_fn=logit_callback_fn, - ) - - batch_size = inputs.shape[0] - # [batch * num_decodes, len] -> [batch, num_decodes, len] - decodes = unflatten_beam_dim(expanded_decodes, batch_size, num_decodes) - # [batch * num_decodes] -> [batch, num_decodes] - log_prob = unflatten_beam_dim(expanded_log_prob, batch_size, num_decodes) - - # Sort `decodes` and `log_prob` by increasing log probabilities of the sampled - # sequence. - # [batch, num_decodes, 1] - idxs = jnp.expand_dims(jnp.argsort(log_prob, axis=-1), axis=-1) - - # returns [batch, num_decodes, len], [batch, num_decodes] in sorted order. - return jnp.take_along_axis(decodes, idxs, axis=1), jnp.take_along_axis( - log_prob, jnp.squeeze(idxs, axis=-1), axis=-1 - ) - - -def _temperature_sample_single_trial( - inputs: jnp.ndarray, - cache: Mapping[str, jnp.ndarray], - tokens_to_logits: Callable[ - [DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]] - ], - eos_id: int, - prng_key: jnp.ndarray, - num_decodes: int = 1, - temperature: Union[float, jnp.ndarray] = 1.0, - topk: int = 20, - topp: Union[float, jnp.ndarray] = 0.0, - initial_index: Optional[jnp.ndarray] = None, - max_decode_steps: Optional[Union[int, jnp.ndarray]] = None, - rescale_log_probs: bool = True, - state_callback_fn: Optional[StateCallbackFn] = None, - logit_callback_fn: Optional[LogitCallbackFn] = None, -) -> Tuple[jax.Array, jax.Array]: - """A helper function for `temperature_sample`.""" - - # We can check the values of topp and topk only if they are not dynamic. - if not (_is_tracer(topp) or _is_tracer(topk)) and topp and topk: - raise ValueError('At most one of `topp` or `topk` may be non-zero.') - - batch_size, max_decode_len = inputs.shape - - if max_decode_steps is not None: - # We can check the max_decode_steps bounds only if it is not dynamic. - if not _is_tracer(max_decode_steps) and max_decode_steps > inputs.shape[1]: - raise ValueError('Cannot decode more steps than the sequence length.') - - # The number of decode steps required to process the prefix is the number - # of non-zero tokens, since inputs[0] == 0 is the BOS token. - # `max_decode_len[j]` is the number of non-padding tokens in the jth element - # of the returned sequences capped at `len(inputs)`, assuming that the - # early stop doesn't occur. This is true with or without - # `max_decode_steps`. - # When the while loop index `i` for the `j`th element `i[j] = - # max_decode_len[j] - 1`, the generated token populate sequences[i[j]+1]]. - # Since sequences[:, 0] is BOS token, the generated token is - # `max_decode_len[j]`th non-padding tokens and hence `j`th element is - # ended. - max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps - max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len) - - # In the case of starting generation from a non-zero index, it is possible for - # one batch element to reach `max_decode_len` number of decoding steps before - # another. In order to let the last element decoder all the way to - # `max_decode_len` number of steps, we add a final garbage token to the end of - # the sequences. Any element that has reached `max_decode_len` before the rest - # of the elements will continually overwrite this token until all elements - # finish. - # [batch, length+1] -> [batch, length+2] - extra_input_tokens = 2 - expanded_prompt_inputs = jnp.append( - inputs, - jnp.zeros((batch_size, extra_input_tokens), dtype=inputs.dtype), - axis=1, - ) - end_marker = jnp.array(eos_id) - - temperature = jnp.asarray(temperature) - - # Initialize sampling loop state. - step = jnp.zeros((), dtype=jnp.int32) - # initial loop PRNGKey - rng0 = prng_key - # the per batch-item holding current token in loop. - if initial_index is None: - # the per batch-item loop position counter. - i0 = jnp.zeros((batch_size), dtype=jnp.int32) - # the per batch-item holding current token in loop. - token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32) - else: - # the per batch-item loop position counter. - i0 = initial_index - # the per batch-item holding current token in loop. - # Select the token that the initial index is pointing to. - token0 = jnp.take_along_axis( - expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1 - ) - # per batch-item state bit indicating if sentence has finished. - ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_) - # (batch, length+2) array containing prefix prompt tokens for sampling loop - # as well as the generated output of newly sampled tokens. - sequences0 = expanded_prompt_inputs - log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32) - sampling_loop_init_state = SamplingLoopState( - step, i0, sequences0, cache, token0, ended0, rng0, log_prob0 - ) - # Initial eos count to be used to determine whether eos is "generated". Many - # inputs follow the format bos, inputs..., eos, targets..., eos. By counting - # the number of eos tokens we can detect when a new one is added, instead of - # just finding the one that probably ends the inputs. - # [batch, 1] - initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True) - - def sampling_loop_cond_fn(state: SamplingLoopState) -> jax.Array: - """Sampling loop termination condition.""" - # Have all sampled sequences reached an end marker? - # Different elements in the batch can be at different loop indices, if any - # of our examples are not at the end, keep going. - all_sequences_ended = jnp.all(state.ended) - return ~all_sequences_ended - - def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState: - """Sampling loop state update.""" - - if state_callback_fn is not None: - state = state_callback_fn(state) - - # Split RNG for sampling. - rng1, rng2 = random.split(state.rng) - # Call fast-decoder model on current tokens to get next-position logits. - frozen = isinstance(state.cache, flax.core.FrozenDict) - decoding_state = DecodingState( - cur_index=state.cur_index, - sequences=state.sequences[:, :-extra_input_tokens], - cur_token=state.cur_token, - cache=state.cache, - ) - logits, new_cache = tokens_to_logits(decoding_state) - if frozen: - new_cache = flax.core.freeze(new_cache) - # Sample next token from logits. - - if logit_callback_fn is not None: - logits = logit_callback_fn(logits, state) - - def sample_logits_with_nonzero_temperature(logits, temperature): - scaled_logits = logits / jnp.maximum(temperature, MIN_TEMPERATURE) - if _is_tracer(topk) or topk: - scaled_logits = jax.lax.cond( - topk > 0, - lambda: binary_search.topk_mask(scaled_logits, topk, NEG_INF), # pytype: disable=wrong-arg-types # jax-ndarray - lambda: scaled_logits, - ) - - # When topp is dynamic, we always use it since we cannot check - # non-zeroness (but it will have no effect if topp is 0.0). - if _is_tracer(topp) or topp: - scaled_logits = binary_search.topp_mask(scaled_logits, topp, NEG_INF) # pytype: disable=wrong-arg-types # jax-ndarray - - # [batch] - next_token = random.categorical(rng1, scaled_logits).astype(jnp.int32) - - # log probability of the current token conditioned on the previously - # sampled and prefix tokens. - # [batch, vocab] -> [batch, vocab] - if rescale_log_probs: - log_probs = jax.nn.log_softmax(scaled_logits) - else: - log_probs = jax.nn.log_softmax(logits) - # [batch, vocab] -> [batch] - next_log_prob = jnp.squeeze( - jnp.take_along_axis( - log_probs, jnp.expand_dims(next_token, axis=1), axis=-1 - ), - axis=-1, - ) - - return (next_token, next_log_prob) - - def sample_logits_with_zero_temperature(logits, temperature): # pylint: disable=unused-argument - # For zero temperature, we always want the greedy output, regardless - # of the values of topk and topp. - - next_token = jnp.argmax(logits, -1).astype(jnp.int32) - - if rescale_log_probs: - next_log_prob = jnp.zeros_like(next_token, dtype=jnp.float32) - else: - log_probs = jax.nn.log_softmax(logits) - next_log_prob = jnp.squeeze( - jnp.take_along_axis( - log_probs, jnp.expand_dims(next_token, axis=1), axis=-1 - ), - axis=-1, - ) - - return (next_token, next_log_prob) - - # Perform sampling with temperature - if len(temperature.shape) == 1: - # Each batch item can have different temperatures. - def map_logits_with_different_temperatures( - logits_batch_item, temperature_batch_item - ): - return lax.cond( - temperature_batch_item > MIN_TEMPERATURE, - sample_logits_with_nonzero_temperature, - sample_logits_with_zero_temperature, - jnp.expand_dims(logits_batch_item, axis=0), - temperature_batch_item, - ) - - (next_token, next_log_prob) = jax.vmap( - map_logits_with_different_temperatures - )(logits, jnp.repeat(temperature, num_decodes)) - next_token = jnp.squeeze(next_token, axis=-1) - next_log_prob = jnp.squeeze(next_log_prob, axis=-1) - else: - # Single temperature value is applied to all batch items. - (next_token, next_log_prob) = lax.cond( - temperature > MIN_TEMPERATURE, - sample_logits_with_nonzero_temperature, - sample_logits_with_zero_temperature, - logits, - temperature, - ) - - # When different batch elements are at different points in the loop counter, - # it is possible that an element that started at a higher index will reach - # `max_decode_len` before other elements. When this happens we need to make - # sure this element continuous overwrites our new garbage collection index. - # Here we clamp `i` to `max_decode_len`. This will cause the a write to - # `max_decode_len + 1` which is the final index in `sequences`. Subsequent - # loop body executions will also get their value clamped causing continual - # overwriting of the final garbage position until all examples are finished. - i = jnp.minimum(state.cur_index, max_decode_len) - - # Only use sampled tokens if we're past provided prefix tokens. - # Select the next token from sequences. - # [batch] - next_input_token = jnp.squeeze( - jnp.take_along_axis( - state.sequences, jnp.expand_dims(i + 1, axis=1), axis=1 - ), - axis=1, - ) - # Check if the next token is padding (a target) or non-padding (an input). - # Mask will have `1` for targets and `0` for inputs. - out_of_prompt = next_input_token == 0 - # Select the sampled next token for targets and the actual next token for - # inputs (teacher forcing). - # [batch] - next_token = next_token * out_of_prompt + next_input_token * ~out_of_prompt - - # only add probability if outside prefix region - # [batch] -> [batch] - next_log_prob = state.log_prob + ( - next_log_prob * out_of_prompt - ) * jnp.squeeze(~state.ended, axis=-1).astype(jnp.int32) - - # [batch] -> [batch, 1] - next_token = jnp.expand_dims(next_token, axis=-1) - - # If end-marker reached for batch item, only emit padding tokens. - # [batch, 1] * [batch, 1] -> [batch, 1] - next_token_or_endpad = next_token * ~state.ended - # Add current sampled tokens to recorded sequences. - one_hot = jax.nn.one_hot( - i + 1, state.sequences.shape[1], dtype=state.sequences.dtype - ) - new_sequences = ( - state.sequences * (1 - one_hot) + next_token_or_endpad * one_hot - ) - # new_sequences = dynamic_update_vector_slice_in_dim(sequences, - # next_token_or_endpad, - # i + 1, - # 0) - # Count eos tokens in the sequences and compare to the initial count - # [batch, 1] - cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True) - # [batch, 1] - - # Have we reached max decoding length? - # We generally index into sequences[:, i + 1], and sequences.shape[1] = - # max_decode_len + 2, therefore i == max_decode_len - 1 will write to - # sequences[-2] which is our last valid location. i == max_decode_len will - # write to sequences[-1] which is our garbage collection token. Thus `i` - # should be strictly less than max_decode_len. - has_additional_eos = cur_eos_count > initial_eos_count - ended = ( - state.ended - | has_additional_eos - | jnp.expand_dims(i >= max_decode_len - 1, axis=1) - ) - - return SamplingLoopState( - state.step + 1, - i + 1, - new_sequences, - new_cache, - next_token_or_endpad, - ended, - rng2, - next_log_prob, - ) - - # Run sampling loop and collect final state. - final_state = lax.while_loop( - sampling_loop_cond_fn, sampling_loop_body_fn, sampling_loop_init_state - ) - - if state_callback_fn is not None: - final_state = state_callback_fn(final_state) - - # Pick part of the state corresponding to the sampled sequences. - final_sequences = final_state.sequences - log_prob = final_state.log_prob - # Drop the first position because they are dummy bos tokens. Drop the new - # garbage collection token at the end too. - return final_sequences[:, 1:-1], log_prob - - -# ------------------------------------------------------------------------------ -# BEAM Sampling -# ------------------------------------------------------------------------------ - - -def brevity_penalty(alpha: float, length: int) -> jnp.ndarray: - """Brevity penalty function for beam search penalizing short sequences. - - Args: - alpha: float: brevity-penalty scaling parameter. - length: int: length of considered sequence. - - Returns: - Brevity penalty score as jax scalar. - """ - return jnp.power(((5.0 + length) / 6.0), alpha) - - -# Beam handling utility functions: - - -def cache_map(fn, cache, apply_to_index: bool = False): - """Maps function over that caches, even multiple caches in various layers. - - Args: - fn: The function to apply. - cache: The cache to apply it to. - apply_to_index: Whether to apply the function to the cache index. - - Returns: - The result of applying `fn` to the cache. - """ - frozen = isinstance(cache, flax.core.FrozenDict) - if frozen: - cache = flax.core.unfreeze(cache) - flat_cache = traverse_util.flatten_dict(cache) - if apply_to_index: - keyvals = flat_cache - else: - keyvals = {k: v for k, v in flat_cache.items() if k[-1] != 'cache_index'} - # Exclude cached relative position bias from beam expansion, etc. - # Also excludes scalar index in absolute position embedder from expansion. - # TODO(levskaya): generalize cache_map to accept a list of leaf names to - # map over, instead of doing this ad-hoc. - exclusion_list = ['cached_bias', 'position_embedder_index'] - keyvals = {k: v for k, v in keyvals.items() if k[-1] not in exclusion_list} - - keyvals = jax.tree.map(fn, keyvals) - flat_cache.update(keyvals) - new_cache = traverse_util.unflatten_dict(flat_cache) - if frozen: - new_cache = flax.core.freeze(new_cache) - return new_cache - - -def add_beam_dim( - x: jnp.ndarray, beam_size: int, offset: int = 0 -) -> jnp.ndarray: - """Creates new beam dimension in non-scalar array and tiles into it.""" - x = jnp.expand_dims(x, axis=offset + 1) - tile_dims = [1] * x.ndim - tile_dims[offset + 1] = beam_size - return jnp.tile(x, tile_dims) - - -def flatten_beam_dim(x: jnp.ndarray, offset: int = 0) -> jnp.ndarray: - """Flattens the first two dimensions of a non-scalar array.""" - xshape = list(x.shape) - b_sz = xshape.pop(offset) - xshape[offset] *= b_sz - return x.reshape(xshape) - - -def unflatten_beam_dim( - x: jnp.ndarray, batch_size: int, beam_size: int, offset: int = 0 -) -> jnp.ndarray: - """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" - assert batch_size * beam_size == x.shape[offset] - xshape = list(x.shape) - newshape = xshape[:offset] + [batch_size, beam_size] + xshape[offset + 1 :] - return x.reshape(newshape) - - -def flat_batch_beam_expand( - x: jnp.ndarray, beam_size: int, offset: int = 0 -) -> jnp.ndarray: - """Expands the each batch item by beam_size in batch_dimension.""" - return flatten_beam_dim(add_beam_dim(x, beam_size, offset), offset) - - -def cache_gather_beams( - nested: PyTree, - beam_indices: jnp.ndarray, - batch_size: int, - old_beam_size: int, - new_beam_size: int, - one_hot: bool = True, - offset: int = 0, -) -> jnp.ndarray: - """Gathers the cache beam slices indexed by beam_indices into new beam array. - - Args: - nested: cache pytree. - beam_indices: array of beam_indices - batch_size: size of batch. - old_beam_size: size of _old_ beam dimension. - new_beam_size: size of _new_ beam dimension. - one_hot: whether to perform gathers by one-hot contraction or directly. - offset: cache axis offset from scanned layers. - - Returns: - New pytree with new beam arrays. - [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] - """ - assert offset in (0, 1), 'general offsets not supported' - if one_hot: - # Gather via one-hot contraction, needed for SPMD partitioning. - oh_beam_indices = jax.nn.one_hot( - beam_indices, old_beam_size, dtype=jnp.int32 - ) - if offset == 0: - - def gather_fn(x): - return jnp.einsum('beo,bo...->be...', oh_beam_indices, x).astype( - x.dtype - ) - - else: - - def gather_fn(x): - return jnp.einsum('beo,lbo...->lbe...', oh_beam_indices, x).astype( - x.dtype - ) - - return cache_map(gather_fn, nested) # pytype: disable=bad-return-type # jax-ndarray - - else: - # True gather via fancy indexing. - batch_indices = jnp.reshape( - jnp.arange(batch_size * new_beam_size) // new_beam_size, - (batch_size, new_beam_size), - ) - if offset == 0: - - def gather_fn(x): - return x[batch_indices, beam_indices] - - else: - - def gather_fn(x): - return x[:, batch_indices, beam_indices] - - return cache_map(gather_fn, nested) - - -def gather_beams( - nested: PyTree, - beam_indices: jnp.ndarray, - batch_size: int, - old_beam_size: int, - new_beam_size: int, - one_hot: bool = True, -) -> jnp.ndarray: - """Gathers the beam slices indexed by beam_indices into new beam array. - - Args: - nested: pytree of arrays or scalars (the latter ignored). - beam_indices: array of beam_indices - batch_size: size of batch. - old_beam_size: size of _old_ beam dimension. - new_beam_size: size of _new_ beam dimension. - one_hot: whether to perform gathers by one-hot contraction or directly. - - Returns: - New pytree with new beam arrays. - [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] - """ - if one_hot: - # Gather via one-hot contraction, needed for SPMD partitioning. - oh_beam_indices = jax.nn.one_hot( - beam_indices, old_beam_size, dtype=jnp.int32 - ) - - def gather_fn(x): - return jnp.einsum('beo,bo...->be...', oh_beam_indices, x).astype(x.dtype) - - return jax.tree.map(gather_fn, nested) - else: - # True gather via fancy indexing. - batch_indices = jnp.reshape( - jnp.arange(batch_size * new_beam_size) // new_beam_size, - (batch_size, new_beam_size), - ) - - def gather_fn(x): - return x[batch_indices, beam_indices] - - return jax.tree.map(gather_fn, nested) - - -def top_k_two_stage(x, k): - """Wrapper around lax.top_k with low-batch optimization. - - Args: - x: tensor with shape f32[batch, num_samples]. - k: integer indicating how many top values to return. - - Returns: - Largest k values and indices with shape (f32[batch, k], s32[batch, k]). - """ - - batch, num_samples = x.shape - num_lanes = 128 - if isinstance(batch, int) and batch <= 8 and num_samples > 8 * num_lanes * k: - # At small batch, when num_samples is sufficiently large, optimize - # execution on TPU by doing TopK in two stages. Reshaping 'x' to fill - # lanes reduces tensor padding in TopK call. - if num_samples % num_lanes != 0: - # Pad input tensor to multiples of num_lanes. - num_samples_rounded_up = num_samples + ( - num_lanes - num_samples % num_lanes - ) - x = jnp.pad( - x, - ((0, 0), (0, num_samples_rounded_up - num_samples)), - mode='constant', - constant_values=-np.inf, - ) - num_samples = num_samples_rounded_up - # Reshape input tensor to fill lanes. - num_samples_sublanes = int(num_samples / num_lanes) - x_reshaped = jnp.reshape(x, (batch * num_lanes, num_samples_sublanes)) - # First stage top_k. - vals, indices = lax.top_k(x_reshaped, k) - indices = jnp.reshape(indices, (batch, num_lanes, k)) - index_offsets = jnp.reshape( - num_samples_sublanes * jnp.arange(num_lanes), (1, num_lanes, 1) - ) - indices = jnp.reshape( - jnp.add(index_offsets, indices), (batch, num_lanes * k) - ) - vals = jnp.reshape(vals, (batch, num_lanes * k)) - # Second stage top_k. - vals_s2, indices_s2 = lax.top_k(vals, k) - indices_s2 = jnp.take_along_axis(indices, indices_s2, axis=1) - return vals_s2, indices_s2 - else: - # Use default TopK implementation. - return lax.top_k(x, k) - - -def gather_topk_beams( - nested: PyTree, - score_or_log_prob: jnp.ndarray, - batch_size: int, - new_beam_size: int, -) -> jnp.ndarray: - """Gathers the top-k beam slices given by score_or_log_prob array. - - Args: - nested: pytree of arrays or scalars (the latter ignored). - score_or_log_prob: [batch_size, old_beam_size] array of values to sort by - for top-k selection of beam slices. - batch_size: int: size of batch. - new_beam_size: int: size of _new_ top-k selected beam dimension - - Returns: - New pytree with new beam arrays containing top k new_beam_size slices. - [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] - """ - _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size) - topk_indices = jnp.flip(topk_indices, axis=1) - return gather_beams( - nested, - topk_indices, - batch_size, - score_or_log_prob.shape[1], - new_beam_size, - ) - - -# Beam search state: - - -@flax.struct.dataclass -class BeamState: - """Holds beam search state data.""" - - # The position of the decoding loop in the length dimension. - cur_index: jax.Array # scalar int32: current decoded length index - # The active sequence log probabilities and finished sequence scores. - live_logprobs: jax.Array # float32: [batch_size, beam_size] - finished_scores: jax.Array # float32: [batch_size, beam_size] - # The current active-beam-searching and finished sequences. - live_seqs: jax.Array # int32: [batch_size, beam_size, max_decode_len] - finished_seqs: jax.Array # int32: [batch_size, beam_size, - # max_decode_len] - # Records which of the 'finished_seqs' is occupied and not a filler slot. - finished_flags: jax.Array # bool: [batch_size, beam_size] - # The current state of the autoregressive decoding caches. - cache: PyTree # Any pytree of arrays, e.g. flax attention Cache object - # Optional array of initial indices from which decoding starts, will be either - # 0s if there is no prompt or None. - initial_index: jax.Array | None - - -def beam_init( - batch_size: int, - beam_size: int, - max_decode_len: int, - cache: Mapping[str, jnp.ndarray], - offset: int = 0, - live_seqs: Optional[jnp.ndarray] = None, - initial_index: Optional[jnp.ndarray] = None, -) -> BeamState: - """Initializes the beam search state data structure.""" - cur_index0 = jnp.array(0) - live_logprobs0 = jnp.tile( - jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1] - ) - finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF - # If we prefill any part of the prompt, then the initial live sequences are - # provided. In reality these will be the last token of the prompt or BOS if - # the prompt (in the batch) is empty. - live_seqs0 = ( - live_seqs - if live_seqs is not None - else jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) - ) - finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) - finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) - # add beam dimension to attention cache pytree elements - # We will have to expand the cache_index if we're given an initial prompt that - # we prefill. - beam_cache0 = cache_map( - lambda x: add_beam_dim(x, beam_size, offset), - cache, - apply_to_index=live_seqs is not None, - ) - return BeamState( - cur_index=cur_index0, - live_logprobs=live_logprobs0, - finished_scores=finished_scores0, - live_seqs=live_seqs0, - finished_seqs=finished_seqs0, - finished_flags=finished_flags0, - cache=beam_cache0, - initial_index=initial_index, - ) - - -def _right_align_prompts(prompts): - """Right align the prompts.""" - - # Implementation note: - # - # A very short code to do this right aligning, would be to vmap a jnp.roll for - # the amount of padding in each example, i.e. max_len - prompt_max_index - # (+-1) - however this is slow. - # - # A faster way, courtesy Jeremiah Willcock, is to shift rows by bitmasking - # the gap and iterating for 1, 2, 4, ... log2(len) bitmasks. - # - # This gives a ~3x speedup over vmapping a roll. - - max_len = prompts.shape[1] - nbits = np.ceil(np.log2(max_len)).astype(np.int32) - indices = jnp.arange(max_len) - prompt_max_index = jnp.argmax((prompts != 0) * indices[None, :], axis=1) - shifts = max_len - prompt_max_index - 1 - for i in range(0, nbits + 1): - bitmask = 2**i - prompts = jnp.where( - jnp.expand_dims(shifts & bitmask, 1), - jnp.pad(prompts, ((0, 0), (bitmask, 0)))[:, :-bitmask], - prompts, - ) - return prompts - - -def _left_align_prompts(prompts): - """Left align the prompts.""" - # See implementation notes in `_right_align_prompts`. - - max_len = prompts.shape[1] - # [0, 1, 2, ... L - 1] - indices = jnp.arange(max_len) - # Indices of non padding positions - 1 based, since `indices` is 0 based. - non_padding_positions = (prompts != 0) * (indices[None, :] + 1) - # Replace all padding with `max_len + 1` - m = jnp.where(non_padding_positions, non_padding_positions, max_len + 1) - # First prompt's index. - shifts = jnp.argmin(m, axis=1) - temp = prompts - nbits = np.ceil(np.log2(max_len)).astype(np.int32) - for i in range(0, nbits + 1): - bitmask = 2**i - temp = jnp.where( - jnp.expand_dims(shifts & bitmask, 1), - jnp.pad(temp, ((0, 0), (0, bitmask)))[:, bitmask:], - temp, - ) - return temp - - -def _pick_last_prompt_token(prompts): - # prompts: i32[batch, length] - prompt_lengths = jnp.sum(prompts != 0, axis=1) - # return value: i32[batch,] - return prompts[jnp.arange(prompts.shape[0]), prompt_lengths] - - -# Beam search routine: -def beam_search( - inputs: jnp.ndarray, - cache: Mapping[str, jnp.ndarray], - tokens_to_logits: Callable[ - [DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]] - ], - eos_id: int, - num_decodes: int = 4, - alpha: float = 0.6, - max_decode_len: Optional[int] = None, - max_decode_step: int = -1, - min_log_prob: float = NEG_INF_VALUE, - decode_rng: Optional[jnp.ndarray] = None, - cache_offset: int = 0, - initial_index: Optional[jnp.ndarray] = None, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Beam search for transformer machine translation. - - If `inputs` has non-zero entries, those values are not modified, i.e., - the sampled values for those positions are discarded. This simulates the - teacher forcing on the prefix positions. - - NOTE: While using initial_index with prompts of variable lengths - To comply with the max_decode_len length requirement, we might now return - sequences that were live (i.e. EOS not decoded yet) when they exceeded their - length allowance along with sequences that finished (i.e. EOS was decoded). - Furthermore there might be sequences that finished decoding after their - max_decode_len was finished, but would appear truncated in the output at - max_decode_len. - - TODO(afrozm): Solve this, if needed, by having a third class of sequences - apart from live and finished called "truncated", then after beam search - completes, we will order them as finished > truncated > live, rather than - finished > live that happens right now. - - Args: - inputs: array: [batch_size, length] int32 sequence of tokens. - cache: flax attention cache. - tokens_to_logits: fast autoregressive decoder function taking single token - slices and cache and returning next-token logits and updated cache. - eos_id: int: id of end-of-sentence token for target vocabulary. - num_decodes: number of decoded sequences to be returned. This is equivalent - to the number of beams used in the beam search. - alpha: float: scaling factor for brevity penalty. - max_decode_len: int: an optional maximum length of decoded sequence. If - None, it uses `inputs.shape[1]` as `max_decode_len`. - max_decode_step: int: maximum number of extra beam search steps allowed. - While using initial_index with prompts of variable lengths, max_decode_len - controls the output length, min_log_prob only stops the beam search if all - beam entries fail to pass the threshold, this will set a hard limit of - number of beam search steps to run. Useful to set a hard limit for the - serving latency. - min_log_prob: the beam search will stop if there is no live beam entry with - higher raw score (ignoring brevity penalty) than this. - decode_rng: Unused decoder RNG seed. - cache_offset: axis offset for cache, arising from scanned layers. - initial_index: Optional[jnp.ndarray], the index from which to start decoding - autoregressively if set. If unset, then we teacher-force the prefix, but - autoregressively (so it will be slow). When set, this also assumes that - the cache is appropriately populated. Since inputs are padded on the left - with BOS = 0, these are also the lengths of the prompts. - - Returns: - Tuple of: - [batch_size, beam_size, max_decode_len] top-scoring sequences - [batch_size, beam_size] beam-search scores. - """ - del decode_rng - # We liberally annotate shape information for clarity below. - - beam_size = num_decodes - - batch_size = inputs.shape[0] - end_marker = jnp.array(eos_id) - if max_decode_len is None: - max_decode_len = inputs.shape[1] - # We start with a dummy token in the beginning so extend the maximum length. - max_decode_len += 1 - - right_aligned_input = None - live_seqs = None - if initial_index is not None: - # Now contains the inputs, but "right aligned" so as to end with the last - # prompt token. - # [batch_size, length] - right_aligned_input = _right_align_prompts(inputs) - # `inputs` now is just the last token of the prompt, right padded to the - # same as before. - length = inputs.shape[1] - inputs = jnp.pad( - right_aligned_input[:, -1][:, None], - ((0, 0), (0, length - 1)), - constant_values=0, - ) - - # Sized [batch, max_decode_len] - live_seqs = jnp.pad( - right_aligned_input[:, -1][:, None], - ((0, 0), (0, max_decode_len - 1)), - constant_values=0, - ) - live_seqs = jnp.expand_dims(live_seqs, axis=1) - live_seqs = jnp.broadcast_to( - live_seqs, (live_seqs.shape[0], num_decodes, live_seqs.shape[-1]) - ) - else: - initial_index = jnp.zeros((batch_size,), dtype=jnp.int32) - - # initialize beam search state - beam_search_init_state = beam_init( - batch_size, - beam_size, - max_decode_len, - cache, - cache_offset, - live_seqs=live_seqs, - initial_index=initial_index, - ) - - def beam_search_loop_cond_fn(state: BeamState) -> jax.Array: - """Beam search loop termination condition.""" - # Have we reached max decoding length? - - # Since we might be starting at different points in the prompts, let's use - # the minimum prompt length to stop conservatively. - cur_index = state.cur_index + jnp.min(state.initial_index) - # Because we mutate the "i+1" position, we stop one token before the end. - not_at_end = cur_index < max_decode_len - 1 - - # If we have ran out of max number of allowed beam search steps. - # Note braces are needed as & has higher precedence than >. - exceed_max_decode_step = (max_decode_step > 0) & jnp.all( - state.cur_index >= max_decode_step - ) - - # Is no further progress in the beam search possible? - # Get the best possible scores from alive sequences. - min_brevity_penalty = brevity_penalty(alpha, max_decode_len) - best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty - # Get the worst scores from finished sequences. - worst_finished_scores = jnp.min( - state.finished_scores, axis=1, keepdims=True - ) - # Mask out scores from slots without any actual finished sequences. - worst_finished_scores = jnp.where( - state.finished_flags, worst_finished_scores, NEG_INF - ) - # If no best possible live score is better than current worst finished - # scores, the search cannot improve the finished set further. - search_terminated = jnp.all(worst_finished_scores > best_live_scores) - - # If no best possible live score is greater than min_log_prob, end search - # early. Note: - # - We are ignoring the brevity penalty as it can over-estimate the scores. - # - state.cur_index > 0 is needed as beam search just starts and live - # beams are empty. - raw_min_log_prob = min_log_prob / min_brevity_penalty - none_pass_min_prob_check = jnp.all(best_live_scores < raw_min_log_prob) & ( - state.cur_index > 0 - ) - - # If we're not at the max decode length, there is at least one beam passing - # the minimum probability threshold, the search hasn't terminated, and it - # doesn't exceed the maximum number of beam search steps, continue looping. - return ( - not_at_end - & (~search_terminated) - & (~none_pass_min_prob_check) - & (~exceed_max_decode_step) - ) - - def beam_search_loop_body_fn(state: BeamState) -> BeamState: - """Beam search loop state update function.""" - # Collect the current position slice along length to feed the fast - # autoregressive decoder model. Flatten the beam dimension into batch - # dimension for feeding into the model. - # --> [batch * beam, 1] - flat_ids = flatten_beam_dim( - lax.dynamic_slice( - state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) - ) - ) - # Flatten beam dimension into batch to be compatible with model. - # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = cache_map( - functools.partial(flatten_beam_dim, offset=cache_offset), state.cache - ) - - # Call fast-decoder model on current tokens to get next-position logits. - # --> [batch * beam, vocab] - decoding_state = DecodingState( - cur_index=state.cur_index, - sequences=flatten_beam_dim(state.live_seqs), - cur_token=flat_ids, - cache=flat_cache, - ) - flat_logits, new_flat_cache = tokens_to_logits(decoding_state) - - # unflatten beam dimension - # [batch * beam, vocab] --> [batch, beam, vocab] - logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) - # Unflatten beam dimension in attention cache arrays - # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = cache_map( - lambda x: unflatten_beam_dim(x, batch_size, beam_size, cache_offset), - new_flat_cache, - ) - - # Gather log probabilities from logits - candidate_log_probs = jax.nn.log_softmax(logits) - # Add new logprobs to existing prefix logprobs. - # --> [batch, beam, vocab] - log_probs = candidate_log_probs + jnp.expand_dims( - state.live_logprobs, axis=2 - ) - - # We'll need the vocab size, gather it from the log probability dimension. - vocab_size = log_probs.shape[-1] - - # Each item in batch has beam_size * vocab_size candidate sequences. - # For each item, get the top 2*k candidates with the highest log- - # probabilities. We gather the top 2*K beams here so that even if the best - # K sequences reach EOS simultaneously, we have another K sequences - # remaining to continue the live beam search. - beams_to_keep = 2 * beam_size - # Flatten beam and vocab dimensions. - flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size)) - # Gather the top 2*K scores from _all_ beams. - # --> [batch, 2*beams], [batch, 2*beams] - topk_log_probs, topk_indices = top_k_two_stage( - flat_log_probs, k=beams_to_keep - ) - - # Append the most probable 2*K token IDs to the top 2*K sequences - # Recover token id by modulo division. - topk_ids = topk_indices % vocab_size - # Force decode `inputs` into topk_ids up until PAD. When `inputs` is all - # PADs this is a no-op. - # - # Also note that when `initial_index` is set, we've already setup the - # inputs so that at position 1 onwards (i.e. state.cur_index + 1 >= 1) - # the tokens are 0 and we'll immediately be "out of prompt". - # --> [batch_size, 1] - next_input_token = jnp.expand_dims(inputs, axis=1).astype(jnp.int32)[ - :, :, state.cur_index + 1 - ] - # --> [batch_size, 1] - out_of_prompt = next_input_token == 0 - - # When forcing prompts, update log probabilities to `0` for the top of the - # beam and -INF for the rest, effectively keeping only one beam alive. - # This is necessary, because if two beams have the same prefix, then they - # will both decode the exact same sequences and that's redundant. - # --> [batch, 2*beams] - inside_prompt_log_probs = jnp.concatenate( - [ - jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype), - jnp.full_like(topk_log_probs[:, : beams_to_keep - 1], NEG_INF), - ], - axis=1, - ) - topk_log_probs = ( - topk_log_probs * out_of_prompt - + inside_prompt_log_probs * ~out_of_prompt - ) - - topk_ids = topk_ids * out_of_prompt + next_input_token * ~out_of_prompt - - # Expand id array for broadcasting - # --> [batch, 2*beams, 1] - topk_ids = jnp.expand_dims(topk_ids, axis=2) - - # Recover the beam index by floor division. - topk_beam_indices = topk_indices // vocab_size - # Gather 2*k top beams. - # --> [batch, 2*beams, length] - topk_seq = gather_beams( - state.live_seqs, topk_beam_indices, batch_size, beam_size, beams_to_keep - ) - # Update sequences for the 2*K top-k new sequences. - # --> [batch, 2*beams, length] - topk_seq = lax.dynamic_update_slice( - topk_seq, topk_ids, (0, 0, state.cur_index + 1) - ) - - # Update LIVE (in-progress) sequences: - # Did any of these sequences reach an end marker? - # --> [batch, 2*beams] - newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker - # To prevent these newly finished sequences from being added to the LIVE - # set of active beam search sequences, set their log probs to a very large - # negative value. - new_log_probs = topk_log_probs + newly_finished * NEG_INF - # Determine the top k beam indices (from top 2*k beams) from log probs. - # --> [batch, beams] - _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size) - new_topk_indices = jnp.flip(new_topk_indices, axis=1) - # Gather the top k beams (from top 2*k beams). - # --> [batch, beams, length], [batch, beams] - top_alive_seq, top_alive_log_probs = gather_beams( - [topk_seq, new_log_probs], - new_topk_indices, - batch_size, - 2 * beam_size, - beam_size, - ) - - # Determine the top k beam indices from the original set of all beams. - # --> [batch, beams] - top_alive_indices = gather_beams( - topk_beam_indices, - new_topk_indices, - batch_size, - 2 * beam_size, - beam_size, - ) - # With these, gather the top k beam-associated caches. - # --> {[batch, beams, ...], ...} - top_alive_cache = cache_gather_beams( - new_cache, - top_alive_indices, - batch_size, - beam_size, - beam_size, - True, - cache_offset, - ) - - # Update FINISHED (reached end of sentence) sequences: - # Calculate new seq scores from log probabilities. - lengths = state.cur_index + 1 - # We should add the lengths of the prompts to the beams as well to - # calculate the brevity penalty correctly. - # initial_index --> [batch_size,] - # topk_lengths --> [batch_size, 2*beams] - topk_lengths = jnp.repeat(initial_index[:, None], beams_to_keep, axis=1) - # lengths is now: [batch_size, 2*beams] - lengths = topk_lengths + lengths - - new_scores = topk_log_probs / brevity_penalty(alpha, lengths) # pytype: disable=wrong-arg-types # jax-devicearray - # Mask out the still unfinished sequences by adding large negative value. - # --> [batch, 2*beams] - new_scores += (~newly_finished) * NEG_INF - - # Combine sequences, scores, and flags along the beam dimension and compare - # new finished sequence scores to existing finished scores and select the - # best from the new set of beams. - finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] - [state.finished_seqs, topk_seq], axis=1 - ) - finished_scores = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_scores, new_scores], axis=1 - ) - finished_flags = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_flags, newly_finished], axis=1 - ) - # --> [batch, beams, length], [batch, beams], [batch, beams] - top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams( - [finished_seqs, finished_scores, finished_flags], - finished_scores, - batch_size, - beam_size, - ) - ) - - return BeamState( - cur_index=state.cur_index + 1, - live_logprobs=top_alive_log_probs, - finished_scores=top_finished_scores, - live_seqs=top_alive_seq, - finished_seqs=top_finished_seq, - finished_flags=top_finished_flags, - cache=top_alive_cache, - initial_index=initial_index, - ) - - # Run while loop and get final beam search state. - final_state = lax.while_loop( - beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state - ) - - # Account for the edge-case where there are no finished sequences for a - # particular batch item. If so, return live sequences for that batch item. - # --> [batch] - any_finished = jnp.any(final_state.finished_flags, axis=1) - # --> [batch, beams, length] - finished_seqs = jnp.where( - any_finished[:, None, None], - final_state.finished_seqs, - final_state.live_seqs, - ) - # --> [batch, beams] - finished_scores = jnp.where( - any_finished[:, None], - final_state.finished_scores, - final_state.live_logprobs, - ) - - # Construct the finished sequences back from the prompts that we kept - # separately in the right aligned buffer. - if right_aligned_input is not None: - # Right now we have right aligned inputs, and then the last tokens + - # completions in finished_seqs. We need to concatenate and get rid of the - # extra padding, while broadcasting in the beam dimension. - - # Drop the first token, because it is also in the `right_aligned_input` - # [batch, beams, length] - finished_seqs = finished_seqs[:, :, 1:] - # right_aligned_input is [batch, length_prompt], we need to create a new - # beams dimension and broadcast it along that. - # --> [batch, beams, length] - right_aligned_input = jnp.broadcast_to( - right_aligned_input[:, None, :], - (batch_size, finished_seqs.shape[1], right_aligned_input.shape[-1]), - ) - # Now concatenate along the length dimension. - # --> [batch, beams, length] - finished_seqs = jnp.concatenate( - [right_aligned_input, finished_seqs], axis=-1 - ) - - # Now we left align everything. - - # First flatten to [batch_size * beams, length] - flat_finished_seqs = jnp.reshape( - finished_seqs, (-1, finished_seqs.shape[-1]) - ) - # Left align everything. - flat_finished_seqs = _left_align_prompts(flat_finished_seqs) - # Shape back to the original shape. - left_aligned_seqs = jnp.reshape(flat_finished_seqs, finished_seqs.shape) - # Cut to the desired length (-1 because we added 1 right off the bat) - finished_seqs = left_aligned_seqs[:, :, : max_decode_len - 1] - else: - # Just drop the first dummy 0 token. - finished_seqs = finished_seqs[:, :, 1:] - - return finished_seqs, finished_scores diff --git a/t5x-main/t5x/decoding_test.py b/t5x-main/t5x/decoding_test.py deleted file mode 100644 index ebac83a67c3cf7d35e928e68d04efd3182e4008e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/decoding_test.py +++ /dev/null @@ -1,1476 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.decoding.""" - -import functools -from typing import Mapping, Tuple -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import jax -from jax.experimental import io_callback -import jax.numpy as jnp -import numpy as np -from t5x import decoding - -PAD_ID = 0 -EOS_ID = 1 -NEG_INF = decoding.NEG_INF - - -class DecodeTest(parameterized.TestCase): - - def test_temperature_sample_uneven_prefix(self): - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Always sample id 2 for batch element 0 and id 3 for element 1. - logits = np.array( - [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ) - return logits, {} - - inputs = np.array([[0, 5, 7, 1, 0, 0], [0, 6, 1, 0, 0, 0]]) - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, - {}, - token_to_logits, - EOS_ID, - jax.random.PRNGKey(0), - topk=0, - initial_index=np.array([3, 2]), - ) - expected = np.array([[5, 7, 1, 2, 2, 2], [6, 1, 3, 3, 3, 3]]) - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_temperature_sample_no_prefix(self): - batch, max_decode_len = 2, 3 - - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Always sample id 2 for batch element 0 and id 3 for element 1. - logits = np.array( - [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ) - return logits, {} - - inputs = np.zeros((batch, max_decode_len), dtype=np.int32) - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0 - ) - - expected = [[2, 2, 2], [3, 3, 3]] - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_temperature_sample_prefix(self): - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Always sample id 2 for batch element 0 and id 3 for element 1. - logits = np.array( - [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ) - return logits, {} - - # batch element 0 has length 3 prefix and element 1 has length 2. - inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32) - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0 - ) - - expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]] - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_temperature_sample_with_zero_temperature(self): - batch, max_decode_len = 2, 3 - - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Use very large logits that are close to one another. - logits = np.array( - [[1700.47, 1700.48, 1700.51, 1700.45], [3.2, 4.8, -5.3, 5.6]], - dtype=np.float32, - ) - return logits, {} - - inputs = np.zeros((batch, max_decode_len), dtype=np.int32) - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, - {}, - token_to_logits, - EOS_ID, - jax.random.PRNGKey(0), - topk=4, - temperature=0.0, - ) - - expected = [[2, 2, 2], [3, 3, 3]] - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_temperature_sample_prefix_ending_with_eos(self): - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Always sample id 2 for batch element 0 and id 3 for element 1. - logits = np.array( - [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ) - return logits, {} - - # batch element 0 has length 4 prefix (including the initial dummy token and - # the last eos) and element 1 has length 3. - inputs = np.array([[0, 5, 6, 1, 0], [0, 8, 1, 0, 0]], dtype=np.int32) - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=1 - ) - - expected = [[5, 6, 1, 2, 2], [8, 1, 3, 3, 3]] - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_temperature_sample_with_state_callback(self): - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # A distribution with roughly all probability mass in sample id 3 - logits = np.array( - [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ) - return logits, {} - - def state_callback_fn(state): - def callback_fn(current_index_and_sequences): - """Add EOS token after first time token id 3 has been sampled.""" - current_index, sequences = current_index_and_sequences - sequences = np.array(sequences) - for i in range(len(current_index)): - if sequences[i, current_index[i]] == 3: - sequences[i, current_index[i] + 1] = EOS_ID - return sequences - - sequences = io_callback( - callback_fn, - jax.ShapeDtypeStruct(state.sequences.shape, state.sequences.dtype), - (state.cur_index, state.sequences), - ) - return state.replace(sequences=sequences) - - inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32) - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, - {}, - token_to_logits, - EOS_ID, - jax.random.PRNGKey(0), - topk=0, - temperature=0.0, - state_callback_fn=state_callback_fn, - ) - - expected = [[5, 6, 7, 3, EOS_ID], [8, 9, 3, EOS_ID, 0]] - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_temperature_sample_with_logit_callback(self): - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # uniform distribution over targets from model - logits = np.array( - [[-1e7, -1e7, -1e7, -1e7], [-1e7, -1e7, -1e7, -1e7]], dtype=np.float32 - ) - return logits, {} - - def logit_callback_fn(logits, state): - del state # unused - # Rewrite logits to always sample id 2 for batch element 0 and - # id 3 for element 1. - logits[0, 2] = 0 - logits[1, 3] = 0 - return logits - - # batch element 0 has length 3 prefix and element 1 has length 2. - inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32) - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, - {}, - token_to_logits, - EOS_ID, - jax.random.PRNGKey(0), - topk=0, - temperature=0.0, - logit_callback_fn=logit_callback_fn, - ) - - expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]] - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_temperature_sample_prefix_ending_with_eos_early_stop(self): - batch, max_decode_len = 2, 7 - rng0 = jax.random.PRNGKey(0) - - ret = [np.array([2, 3]) for _ in range(max_decode_len)] - # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter of - # `decoding._temperature_sample_single_trial`. - ret[3] = np.array([2, 1]) - # Sequence 0 outputs EOS=1 when i = 4. - ret[4] = np.array([1, 3]) - ret = jax.numpy.array(ret) - - def mocked_categorical(rng_input, logits): # pylint: disable=unused-argument - """Ignores logit and returns only based on the rng_input.""" - rng = rng0 - k = 0 - # Mimic the rng split done in `decoding.sample_loop_body_fn`. - for j in range(max_decode_len): - rng1, rng = jax.random.split(rng) - # We want to sift out `j` for which rng1 == rng_input - k += j * (rng1 == rng_input).all() - # `k` at this point is equal to the while loop variable `i` of the caller. - return ret[k] - - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # These values are not used in this test because random.categorical is - # directly mocked. - dummy_logits = np.zeros((batch, 4), dtype=np.float32) - return dummy_logits, {} - - inputs = np.array( - [[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]], dtype=np.int32 - ) - with mock.patch.object(jax.random, 'categorical', new=mocked_categorical): - sampled_sequences, _ = decoding._temperature_sample_single_trial( - inputs, {}, token_to_logits, EOS_ID, rng0, topk=0 - ) - - expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]] - np.testing.assert_array_equal(expected, sampled_sequences) - - def test_greedy_decoding_topk_sample_log_probs(self): - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Sample [2, 3] with probability [0.6, 0.4]. - logits = np.array( - [[-1e7, -1e7, -0.510825624, -0.916290732]], dtype=np.float32 - ) - return logits, {} - - inputs = np.array([[0, 2, 2, 2, 0]], dtype=np.int32) - sampled_sequences, sampled_log_probs = ( - decoding._temperature_sample_single_trial( - inputs, - {}, - token_to_logits, - EOS_ID, - jax.random.PRNGKey(0), - topk=1, - rescale_log_probs=True, - ) - ) - - expected_sequence = [[2, 2, 2, 2, 2]] - expected_log_probs = [0.0] - np.testing.assert_array_equal(expected_sequence, sampled_sequences) - np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs) - - inputs = np.array([[0, 2, 2, 3, 0]], dtype=np.int32) - sampled_sequences, sampled_log_probs = ( - decoding._temperature_sample_single_trial( - inputs, - {}, - token_to_logits, - EOS_ID, - jax.random.PRNGKey(0), - topk=1, - rescale_log_probs=False, - ) - ) - - expected_sequence = [[2, 2, 3, 2, 2]] - expected_log_probs = [-1.02165125] - np.testing.assert_array_equal(expected_sequence, sampled_sequences) - np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs) - - def test_temperature_sample_log_prob(self): - batch, max_decode_len = 2, 7 - rng0 = jax.random.PRNGKey(0) - - ret = [np.array([2, 3]) for _ in range(max_decode_len)] - # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter of - # `decoding._temperature_sample_single_trial`. - ret[3] = np.array([2, 1]) - # Sequence 0 outputs EOS=1 when i = 4. - ret[4] = np.array([1, 3]) - ret = jax.numpy.array(ret) - - # TODO(hwchung): refactor this. - def mocked_categorical(rng_input, logits): # pylint: disable=unused-argument - """Ignores logit and returns only based on the rng_input.""" - rng = rng0 - k = 0 - # Mimic the rng split done in `decoding.sample_loop_body_fn`. - for j in range(max_decode_len): - rng1, rng = jax.random.split(rng) - # We want to sift out `j` for which rng1 == rng_input - k += j * (rng1 == rng_input).all() - # `k` at this point is equal to the while loop variable `i` of the caller. - return ret[k] - - logits = np.random.randn(batch, 4) - token_to_logits = lambda decoding_state: (logits, {}) - inputs = np.array( - [[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]], dtype=np.int32 - ) - with mock.patch.object(jax.random, 'categorical', new=mocked_categorical): - sampled_sequences, log_prob = decoding._temperature_sample_single_trial( - inputs, {}, token_to_logits, EOS_ID, rng0, topk=0 - ) - - log_probs = jax.nn.log_softmax(logits) - expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]] - expected_log_prob = [ - log_probs[0, 2] + log_probs[0, 2] + log_probs[0, 1], - log_probs[1, 3] + log_probs[1, 3] + log_probs[1, 1], - ] - expected_log_prob = np.array(expected_log_prob) - np.testing.assert_array_equal(expected, sampled_sequences) - np.testing.assert_allclose(expected_log_prob, log_prob, atol=1e-5) - - def test_temperature_sample_num_decodes(self): - num_decodes = 3 - rng0 = jax.random.PRNGKey(0) - inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32) - - with mock.patch.object( - decoding, '_temperature_sample_single_trial' - ) as mocked: - # expanded_decodes: [batch * num_decodes, max_decode_len] - expanded_decodes = np.array([ - [5, 1, 4, 4], - [5, 1, 5, 5], - [5, 1, 3, 3], - [8, 7, 5, 5], - [8, 7, 3, 3], - [8, 7, 4, 4], - ]) - # expanded_log_prob: [batch * num_decodes] - expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9]) - mocked.return_value = expanded_decodes, expanded_log_prob - - decodes, scores = decoding.temperature_sample( - inputs, {}, mock.Mock(), EOS_ID, rng0, num_decodes=num_decodes - ) - - expanded_inputs = jnp.array([ - [0, 5, 1, 0], - [0, 5, 1, 0], - [0, 5, 1, 0], - [0, 8, 7, 0], - [0, 8, 7, 0], - [0, 8, 7, 0], - ]) - # Test that the actual decode function is called with the expanded values. - np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs) - - np.testing.assert_array_equal( - decodes, - [ - [[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]], - [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]], - ], - ) - np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]]) - - def test_temperature_sample_num_decodes_with_initial_index(self): - num_decodes = 3 - rng0 = jax.random.PRNGKey(0) - inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32) - initial_index = np.array([1, 2], dtype=np.int32) - - with mock.patch.object( - decoding, '_temperature_sample_single_trial' - ) as mocked: - with mock.patch.object(decoding, 'cache_map') as mocked_cache_map: - # expanded_decodes: [batch * num_decodes, max_decode_len] - expanded_decodes = np.array([ - [5, 1, 4, 4], - [5, 1, 5, 5], - [5, 1, 3, 3], - [8, 7, 5, 5], - [8, 7, 3, 3], - [8, 7, 4, 4], - ]) - # expanded_log_prob: [batch * num_decodes] - expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9]) - mocked.return_value = expanded_decodes, expanded_log_prob - - decodes, scores = decoding.temperature_sample( - inputs, - {}, - mock.Mock(), - EOS_ID, - rng0, - num_decodes=num_decodes, - initial_index=initial_index, - ) - - expanded_inputs = jnp.array([ - [0, 5, 1, 0], - [0, 5, 1, 0], - [0, 5, 1, 0], - [0, 8, 7, 0], - [0, 8, 7, 0], - [0, 8, 7, 0], - ]) - expanded_initial_index = np.array([1, 1, 1, 2, 2, 2], dtype=np.int32) - # Test that the actual decode function is called with the expanded - # values. - np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs) - np.testing.assert_array_equal( - mocked.call_args[1]['initial_index'], expanded_initial_index - ) - # Test that the function was applied to the index in the cache map - self.assertTrue(mocked_cache_map.call_args[1]['apply_to_index']) - - np.testing.assert_array_equal( - decodes, - [ - [[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]], - [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]], - ], - ) - np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]]) - - @parameterized.named_parameters( - dict( - testcase_name='no_initial_index', - initial_index=None, - expected_calls=6, - ), - dict( - testcase_name='initial_index', - initial_index=np.array([1, 2], dtype=np.int32), - expected_calls=4, - ), - dict( - testcase_name='lower_initial_index', - initial_index=np.array([1, 1], dtype=np.int32), - expected_calls=5, # we decode 4 tokens out of the prompt - ), - ) - def test_temperature_sample_max_decode_steps_with_initial_index( - self, initial_index, expected_calls - ): - max_decode_steps = 4 - rng0 = jax.random.PRNGKey(0) - inputs = np.array( - [[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]], dtype=np.int32 - ) - - token_to_logits = mock.Mock() - token_to_logits.return_value = ( - np.array( - [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ), - {}, - ) - - # to unroll while loop - with jax.disable_jit(): - decodes, scores = decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - initial_index=initial_index, - topk=4, - max_decode_steps=max_decode_steps, - ) - - self.assertLen(token_to_logits.call_args_list, expected_calls) - - expected_output = np.array( - [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 3, 3, 3, 3, 0, 0]] - ) - expected_output = jnp.expand_dims(expected_output, 1) - - np.testing.assert_array_equal(decodes, expected_output) - np.testing.assert_array_equal(scores, [[0.0], [0.0]]) - - def test_temperature_sample_max_decode_steps_endpad(self): - max_decode_steps = 4 - rng0 = jax.random.PRNGKey(0) - inputs = np.array( - [ - [0, 2, 0, 0, 0, 0, 0, 0], - [0, 2, 2, 2, 2, 2, 2, 0], - [0, 2, 2, 2, 0, 0, 0, 0], - ], - dtype=np.int32, - ) - initial_index = np.array([1, 6, 0]) - - token_to_logits = mock.Mock() - token_to_logits.return_value = ( - np.array( - [ - [-1e7, -1e7, -1e7, 0], - [-1e7, -1e7, -1e7, 0], - [-1e7, -1e7, -1e7, 0], - ], - dtype=np.float32, - ), - {}, - ) - - # to unroll while loop - with jax.disable_jit(): - decodes, scores = decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - initial_index=initial_index, - topk=4, - max_decode_steps=max_decode_steps, - ) - - # `inputs[2]` starts from index 0. So it requires 3 calls to - # `token_to_logits` to exit the prompt (these generated tokens are - # overridden) and 4 more calls to fill the rest. `inputs[0]` only need 4 - # calls. In the last 3 calls, it generates but MUST NOT populate the - # sequences because it is already ended. - self.assertLen(token_to_logits.call_args_list, 7) - expected_output = np.array( - [ - [2, 3, 3, 3, 3, 0, 0, 0], - [2, 2, 2, 2, 2, 2, 3, 3], - [2, 2, 2, 3, 3, 3, 3, 0], - ], - dtype=np.int32, - ) - expected_output = jnp.expand_dims(expected_output, 1) - - np.testing.assert_array_equal(decodes, expected_output) - np.testing.assert_allclose(scores, [[0.0], [0.0], [0.0]]) - - def test_temperature_sample_max_decode_steps_docstring_ex4(self): - max_decode_steps = 2 - rng0 = jax.random.PRNGKey(0) - inputs = np.array( - [[0, 2, 0, 0, 0, 0, 0, 0], [0, 3, 4, 0, 0, 0, 0, 0]], dtype=np.int32 - ) - initial_index = np.array([1, 2]) - - token_to_logits = mock.Mock() - token_to_logits.return_value = ( - np.array( - [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ), - {}, - ) - - # to unroll while loop - with jax.disable_jit(): - decodes, _ = decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - initial_index=initial_index, - topk=4, - max_decode_steps=max_decode_steps, - ) - self.assertLen(token_to_logits.call_args_list, 2) - expected_output = np.array( - [[2, 2, 2, 0, 0, 0, 0, 0], [3, 4, 3, 3, 0, 0, 0, 0]], dtype=np.int32 - ) - expected_output = jnp.expand_dims(expected_output, 1) - - np.testing.assert_array_equal(decodes, expected_output) - - def test_temperature_sample_max_decode_steps_hard_limit(self): - max_decode_steps = 10 - max_decode_steps_hard_limit = 4 - rng0 = jax.random.PRNGKey(0) - inputs = np.array( - [[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]], dtype=np.int32 - ) - - token_to_logits = mock.Mock() - token_to_logits.return_value = ( - np.array( - [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 - ), - {}, - ) - - # to unroll while loop - with jax.disable_jit(): - decodes, scores = decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - topk=4, - max_decode_steps=max_decode_steps, - max_decode_steps_hard_limit=max_decode_steps_hard_limit, - ) - - expected_output = np.array( - [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 3, 3, 3, 3, 0, 0]] - ) - expected_output = jnp.expand_dims(expected_output, 1) - - np.testing.assert_array_equal(decodes, expected_output) - np.testing.assert_array_equal(scores, [[0.0], [0.0]]) - - def test_temperature_sample_topp(self): - rng0 = jax.random.PRNGKey(0) - inputs = np.zeros((1, 20), dtype=np.int32) - - token_to_logits = mock.Mock() - - # logits correspond to (0.3, 0, 0.1, 0.6) - token_to_logits.return_value = ( - np.array([[-1.2, -1e7, -2.3, -0.51]], dtype=np.float32), - {}, - ) - - decodes, scores = decoding.temperature_sample( - inputs, {}, token_to_logits, EOS_ID, rng0, topp=0.55, topk=0 - ) # anything under 0.6 will trigger deterministic decoding. - - expected_output = np.array([[3] * 20]) - expected_output = jnp.expand_dims(expected_output, 1) - - np.testing.assert_array_equal(decodes, expected_output) - np.testing.assert_array_equal(scores, [[0.0]]) - - # temperature is applied first, so the distribution becomes - # (0.27, 0, 0.069, 0.65), so if topp is 0.63, it should become greedy. - decodes, scores = decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - temperature=0.8, - topp=0.63, - topk=0, - ) - - expected_output = np.array([[3] * 20]) - expected_output = jnp.expand_dims(expected_output, 1) - - np.testing.assert_array_equal(decodes, expected_output) - np.testing.assert_array_equal(scores, [[0.0]]) - - def test_temperature_sample_per_item_temperature(self): - rng0 = jax.random.PRNGKey(0) - - # 4 batches of 20 sequence length. - batch_size = 4 - num_decodes = 2 - seq_length = 20 - inputs = np.zeros((batch_size, seq_length), dtype=np.int32) - token_to_logits = mock.Mock() - token_to_logits.return_value = ( - np.repeat( - np.array( - # First batch logits correspond to (0.3, 0, 0.1, 0.6). - # The rest of the batches just change the order. - [ - [-1.2, -1e7, -2.3, -0.51], - [-0.51, -1.2, -1e7, -2.3], - [-1.2, -1e7, -0.51, -2.3], - [-1.2, -0.51, -1e7, -2.3], - ], - dtype=np.float32, - ), - num_decodes, - axis=0, - ), - {}, - ) - - # temperature is applied first, so if topp is 0.63 and temperature < 0.8, - # it should become greedy. - decodes, scores = decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - temperature=np.array([0.5, 0, 0.3, 0.2]), - topp=0.63, - topk=0, - num_decodes=num_decodes, - ) - - # Last batch item ends in 0s because 1 is EOS ID. - expected_output = np.array([ - [3] * seq_length, - [0] * seq_length, - [2] * seq_length, - [1] + [0] * (seq_length - 1), - ]) - # Expand number of decodes dimension. - expected_output = jnp.expand_dims(expected_output, 1) - expected_output = jnp.repeat(expected_output, num_decodes, axis=1) - - np.testing.assert_array_equal(decodes, expected_output) - np.testing.assert_array_equal(scores, [[0.0] * num_decodes] * batch_size) - - def test_dynamic_topp_max_decode_steps(self): - rng0 = jax.random.PRNGKey(0) - inputs = np.zeros((1, 20), dtype=np.int32) - - token_to_logits = mock.Mock() - - # logits correspond to (0.3, 0, 0.1, 0.6) - token_to_logits.return_value = ( - np.array([[-1.2, -1e7, -2.3, -0.51]], dtype=np.float32), - {}, - ) - - def dynamic_decode_fn(inputs, temperature, topp, max_decode_steps): - return decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - temperature=temperature, - topp=topp, - topk=0, - max_decode_steps=max_decode_steps, - ) - - dynamic_decode_fn_jit = jax.jit(dynamic_decode_fn) - - decodes, scores = dynamic_decode_fn_jit(inputs, 0.8, 0.63, 10) - - expected_output = np.array([[3] * 10 + [0] * 10]) - expected_output = jnp.expand_dims(expected_output, 1) - - np.testing.assert_array_equal(decodes, expected_output) - np.testing.assert_array_equal(scores, [[0.0]]) - - def test_topp_log_probs(self): - rng0 = jax.random.PRNGKey(0) - inputs = np.zeros((1, 1), dtype=np.int32) - - token_to_logits = mock.Mock() - - # logits correspond to (0.3, 0, 0.1, 0.6) - token_to_logits.return_value = ( - np.array([[-1.2, NEG_INF, -2.3, -0.51]], dtype=np.float32), - {}, - ) - - with jax.disable_jit(): - # this lets us see logits after topp and topk are applied - with mock.patch.object(jax.random, 'categorical') as mocked: - mocked.return_value = jnp.array([0], dtype=jnp.int32) - decodes, _ = decoding.temperature_sample( - inputs, - {}, - token_to_logits, - EOS_ID, - rng0, - temperature=1.4, - topp=0.7, - topk=0, - ) - - self.assertLen(token_to_logits.call_args_list, 1) - np.testing.assert_array_equal(decodes, jnp.asarray([[[0]]])) - - np.testing.assert_array_almost_equal( - mocked.call_args_list[0][0][1], - jnp.asarray([[-0.85714293, NEG_INF, NEG_INF, -0.36428571]]), - ) - - def test_add_beam_dim(self): - x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32) - y = decoding.add_beam_dim(x, beam_size=3) - self.assertEqual(y.shape, (2, 3, 4)) - np.testing.assert_array_equal( - [ - [[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0]], - [[0, 8, 6, 9], [0, 8, 6, 9], [0, 8, 6, 9]], - ], - y, - ) - - def test_flat_batch_beam_expand(self): - x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32) - np.testing.assert_array_equal( - [[0, 5, 1, 0], [0, 5, 1, 0], [0, 8, 6, 9], [0, 8, 6, 9]], - decoding.flat_batch_beam_expand(x, beam_size=2), - ) - - def test_top_k_two_stage(self): - def _test_top_k(batch_size, k): - # Pick sufficiently large seq_len. - seq_len = 2047 * k * batch_size - seq = np.arange(seq_len) - np.random.shuffle(seq) - x = jnp.reshape(seq, (batch_size, int(seq_len / batch_size))).astype( - jnp.float32 - ) - np.testing.assert_almost_equal( - decoding.top_k_two_stage(x, k), jax.lax.top_k(x, k), decimal=5 - ) - - # Test small batch cases (batch={1,8}, k=16). - _test_top_k(1, 16) - _test_top_k(8, 16) - # Test large batch cases (batch={9,32}, k=11). - _test_top_k(9, 11) - _test_top_k(32, 11) - - def test_cache_map(self): - cache = { - 'layers_0': { - 'cached_key': jnp.ones([3, 6]), - 'cached_values': jnp.ones([3, 6]), - 'cache_index': jnp.ones([ - 3, - ]), - }, - 'layers_1': { - 'self_attention': { - 'cached_key': jnp.ones([2, 7]), - 'cached_values': jnp.ones([5, 8]), - 'cache_index': jnp.array(1), - }, - 'encoder_decoder_attention': { - 'cached_key': jnp.ones([10, 12, 2]), - 'cached_values': jnp.ones([4, 7, 2]), - 'cache_index': jnp.ones([4, 5, 6]), - }, - }, - } - - fn = functools.partial(jnp.add, 4) - - gold_cache = { - 'layers_0': { - 'cached_key': fn(jnp.ones([3, 6])), - 'cached_values': fn(jnp.ones([3, 6])), - 'cache_index': jnp.ones([ - 3, - ]), - }, - 'layers_1': { - 'self_attention': { - 'cached_key': fn(jnp.ones([2, 7])), - 'cached_values': fn(jnp.ones([5, 8])), - 'cache_index': jnp.array(1), - }, - 'encoder_decoder_attention': { - 'cached_key': fn(jnp.ones([10, 12, 2])), - 'cached_values': fn(jnp.ones([4, 7, 2])), - 'cache_index': jnp.ones([4, 5, 6]), - }, - }, - } - - jax.tree.map( - np.testing.assert_array_equal, decoding.cache_map(fn, cache), gold_cache - ) - - def test_cache_map_with_index(self): - cache = { - 'layers_0': { - 'cached_key': jnp.ones([3, 6]), - 'cached_values': jnp.ones([3, 6]), - 'cache_index': jnp.ones([ - 3, - ]), - }, - 'layers_1': { - 'relpos_bias': { - 'cached_bias': jnp.ones([1, 5, 3]), - }, - 'self_attention': { - 'cached_key': jnp.ones([2, 7]), - 'cached_values': jnp.ones([5, 8]), - 'cache_index': jnp.array(1), - }, - 'encoder_decoder_attention': { - 'cached_key': jnp.ones([10, 12, 2]), - 'cached_values': jnp.ones([4, 7, 2]), - 'cache_index': jnp.ones([4, 5, 6]), - }, - }, - 'position_embedder': { - 'position_embedder_index': jnp.array([-1]), - }, - } - - fn = functools.partial(jnp.add, 8) - - gold_cache = { - 'layers_0': { - 'cached_key': fn(jnp.ones([3, 6])), - 'cached_values': fn(jnp.ones([3, 6])), - 'cache_index': fn( - jnp.ones([ - 3, - ]) - ), - }, - 'layers_1': { - 'relpos_bias': { - 'cached_bias': jnp.ones([1, 5, 3]), - }, - 'self_attention': { - 'cached_key': fn(jnp.ones([2, 7])), - 'cached_values': fn(jnp.ones([5, 8])), - 'cache_index': fn(jnp.array(1)), - }, - 'encoder_decoder_attention': { - 'cached_key': fn(jnp.ones([10, 12, 2])), - 'cached_values': fn(jnp.ones([4, 7, 2])), - 'cache_index': fn(jnp.ones([4, 5, 6])), - }, - }, - 'position_embedder': { - 'position_embedder_index': jnp.array([-1]), - }, - } - - jax.tree.map( - np.testing.assert_array_equal, - decoding.cache_map(fn, cache, apply_to_index=True), - gold_cache, - ) - - def test_beam_search(self): - # Toy problem, we have 4 states, A, B, START, END, (plus PAD). - # Scores are given by a first-order Markov model. - batch_size = 2 - beam_size = 2 - # PAD doesn't matter for this test, but part of the contract for beam_search - # is giving the PAD token id 0. - states = ['PAD', 'A', 'B', 'START-', '-END'] - num_states = len(states) - decode_length = 7 - - # Edge potentials (written inside edges for diagonals): - # 1 -1 1 -1 - # A ---- A ---- A ---- A ---- A - # 0 \ -1 \ 1 \ -1 \ 1 0 - # START X X X X END - # 0 / -1 / 1 / -1 / 1 0 - # B ---- B ---- B ---- B ---- B - # 1 -1 1 -1 - - # put the above edge potentials in a 3-tensor - ab_edge_potentials = np.asarray([ - [[1, -1], [-1, 1]], - [[-1, 1], [1, -1]], - [[1, -1], [-1, 1]], - [[-1, 1], [1, -1]], - ]) - # now we have to add on the START, END states - # and PAD at 0 - edge_potentials = np.ones([6, 5, 5]) * NEG_INF - edge_potentials[1:5, 1:3, 1:3] = ab_edge_potentials - # START can go to either A or B for free at t0 - edge_potentials[0, 3, 1] = 0 - edge_potentials[0, 3, 2] = 0 - # either A or B can go to END for free at t5 - edge_potentials[5, 1, 4] = 0 - edge_potentials[5, 2, 4] = 0 - # PAD can go to anything for free (doesn't matter for this test) - edge_potentials[:, 0, :] = 0 - - edge_potentials = jnp.asarray(edge_potentials) - - # at time 0, we start with state=START=3 - logits0 = jnp.asarray([NEG_INF, NEG_INF, NEG_INF, 0, NEG_INF]) - - # add dummy flattened batch x beam dim for broadcasting - logits0 = jnp.expand_dims(logits0, axis=0) - edge_potentials = jnp.expand_dims(edge_potentials, axis=0) - - def tokens_to_logits( - decoding_state: decoding.DecodingState, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - token_indices = decoding_state.cur_token - state_cache = decoding_state.cache - cur_iter = state_cache['cur_iter'] - # grab edge potentials for the current timestep - cur_edge_potentials = jnp.take_along_axis( - edge_potentials, - jnp.reshape( - jnp.maximum(0, cur_iter[:, 0].astype(jnp.int32) - 1), - (batch_size * beam_size, 1, 1, 1), - ), - axis=1, - ) - cur_edge_potentials = jnp.squeeze(cur_edge_potentials, axis=1) - # get "logits" from edge potentials for requested tokens (except at t0) - cur_logits = jnp.matmul( - jnp.reshape( - jax.nn.one_hot(token_indices, num_states, axis=1), - (batch_size * beam_size, 1, num_states), - ), - cur_edge_potentials, - ) - cur_logits = jnp.squeeze(cur_logits, axis=1) - # use our START-only logits for t0, otherwise use the edge potentials - logits_for_tokens = jnp.where(cur_iter == 0, logits0, cur_logits) - # update state in the cache - new_cache = state_cache.copy() - new_cache['cur_iter'] = cur_iter + 1 - return logits_for_tokens, new_cache - - init_cache = {} - init_cache['cur_iter'] = jnp.zeros((batch_size, 1)) - - top_scoring, _ = decoding.beam_search( - inputs=np.zeros([batch_size, decode_length]), - cache=init_cache, - tokens_to_logits=tokens_to_logits, - eos_id=4, - num_decodes=beam_size, - alpha=0.0, - max_decode_len=decode_length, - ) - - # The two top scoring sequences should be a tie between - # START-AABBA-END - # and - # START-BBAAB-END - # (and greedy beam search will find both these with just two beams) - - top_scoring_strings = [ - ''.join(states[tok] for tok in top_scoring[0, i, :]) - for i in range(beam_size) - ] - - expected = ['START-AABBA-END', 'START-BBAAB-END'] - np.testing.assert_array_equal(expected, top_scoring_strings) - - def test_beam_search_force_decode_prefix(self): - beam_size = 2 - - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1. - logits = np.repeat( - np.expand_dims( - np.array( - [ - [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], - [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], - ], - dtype=np.float32, - ), - axis=1, - ), - [beam_size], - axis=1, - ) - logits = decoding.flatten_beam_dim(logits) - return logits, {} - - # batch element 0 has length 1 and element 1 has length 2. - inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32) - rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32) - beam_search_sequences, decoding_scores = decoding.beam_search( - inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size, alpha=0 - ) - - # Prefixes are forced depending on inputs. - # Beam search sequences and corresponding scores are in reverse order. - self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) - expected = np.array( - [[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]], [[4, 5, 2, 3, 3], [4, 5, 3, 3, 3]]] - ) - np.testing.assert_array_equal(expected, beam_search_sequences) - - expected_scores = [] - batch_logits = np.array( - [ - [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], - [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], - ], - dtype=np.float32, - ) - for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs): - beam_expected_scores = [] - for beam in batch: - log_probs = jax.nn.log_softmax(logits) - # Add them directly since they are static. - beam_scores = [] - for token, prompt_token in zip(beam, prompt): - if prompt_token != 0: - beam_scores.append(0) - elif token == PAD_ID: - beam_scores.append(0) - else: - beam_scores.append(log_probs[token]) - beam_expected_scores.append(sum(beam_scores)) - expected_scores.append(beam_expected_scores) - np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5) - - def test_beam_search_force_decode_no_prefix(self): - beam_size = 2 - - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1. - logits = np.repeat( - np.expand_dims( - np.array( - [[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]], - dtype=np.float32, - ), - axis=1, - ), - [beam_size], - axis=1, - ) - logits = decoding.flatten_beam_dim(logits) - return logits, {} - - # No prefix is passed. - inputs = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32) - beam_search_sequences, decoding_scores = decoding.beam_search( - inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size - ) - - # Prefixes are forced depending on inputs. - # Beam search sequences and corresponding scores are in reverse order. - self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) - expected = np.array( - [[[3, 2, 2, 2, 2], [2, 2, 2, 2, 2]], [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]] - ) - np.testing.assert_array_equal(expected, beam_search_sequences) - - def test_align_prompt(self): - prompts = np.array( - [ - [0, 0, 0, 0, 0, 0, 0], - [1, 2, 3, 4, 5, 6, 7], - [0, 1, 0, 0, 0, 0, 0], - [0, 1, 2, 0, 0, 0, 0], - [0, 1, 2, 3, 0, 0, 0], - ], - dtype=np.int32, - ) - right_aligned_prompts = decoding._right_align_prompts(prompts) - left_aligned_prompts = decoding._left_align_prompts(right_aligned_prompts) - np.testing.assert_array_equal( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0], - [1, 2, 3, 4, 5, 6, 7], - [0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1, 2], - [0, 0, 0, 0, 1, 2, 3], - ], - dtype=np.int32, - ), - right_aligned_prompts, - ) - np.testing.assert_array_equal( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0], - [1, 2, 3, 4, 5, 6, 7], - [1, 0, 0, 0, 0, 0, 0], - [1, 2, 0, 0, 0, 0, 0], - [1, 2, 3, 0, 0, 0, 0], - ], - dtype=np.int32, - ), - left_aligned_prompts, - ) - - def test_beam_search_force_decode_prefix_with_initial_index(self): - beam_size = 2 - - record_decoding_states = [] - - def token_to_logits(decoding_state: decoding.DecodingState): - # Record the decoding_state coming in. - # pdb.set_trace() - record_decoding_states.append(decoding_state) - - # Use id 2 then 3 for batch element 0 and id 3, 2 then EOS for element 1. - logits = np.repeat( - np.expand_dims( - np.array( - [ - [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], - [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], - ], - dtype=np.float32, - ), - axis=1, - ), - [beam_size], - axis=1, - ) - - logits = decoding.flatten_beam_dim(logits) - # Return the cache as-is. - return logits, decoding_state.cache - - # batch element 0 has length 1 and element 1 has length 2. - inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32) - batch_size = inputs.shape[0] - rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32) - initial_index = np.array([1, 2], dtype=np.int32) - REST_OF_THE_SHAPE = 1024 # dummy pylint: disable=invalid-name - dummy_cache = { - 'cached_bias': np.ones((1, REST_OF_THE_SHAPE), dtype=np.float32), - 'decoder/layers_0/self_attention/cached_key': np.ones( - (batch_size, REST_OF_THE_SHAPE), dtype=np.float32 - ), - 'decoder/layers_0/self_attention/cache_index': np.ones( - (batch_size,), dtype=np.float32 - ), - } - - # Since we are capturing the cache, etc. - with jax.disable_jit(): - beam_search_sequences, decoding_scores = decoding.beam_search( - inputs, - dummy_cache, - token_to_logits, - EOS_ID, - num_decodes=beam_size, - alpha=0, - initial_index=initial_index, - ) - - # pdb.set_trace() - - # Since we're sending in a decode prefix, the first tokens that should get - # decoded are the last tokens in the prompt - broadcasted to the beam size. - expected_first_tokens = np.array([[7], [7], [5], [5]], dtype=np.int32) - np.testing.assert_array_equal( - expected_first_tokens, record_decoding_states[0].cur_token - ) - - # Assert on the expected cache shapes that `token_to_logits` should see. - first_cache = record_decoding_states[0].cache - - # This shouldn't expand. - self.assertEqual( - dummy_cache['cached_bias'].shape, first_cache['cached_bias'].shape - ) - # These should expand. - self.assertEqual( - (batch_size * beam_size, REST_OF_THE_SHAPE), - first_cache['decoder/layers_0/self_attention/cached_key'].shape, - ) - self.assertEqual( - (batch_size * beam_size,), - first_cache['decoder/layers_0/self_attention/cache_index'].shape, - ) - - # Prefixes are forced depending on inputs. - # Beam search sequences and corresponding scores are in reverse order. - self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) - expected = np.array( - [[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]], [[4, 5, 3, 1, 0], [4, 5, 1, 0, 0]]] - ) - np.testing.assert_array_equal(expected, beam_search_sequences) - - expected_scores = [] - batch_logits = np.array( - [ - [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], - [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], - ], - dtype=np.float32, - ) - for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs): - beam_expected_scores = [] - for beam in batch: - log_probs = jax.nn.log_softmax(logits) - # Add them directly since they are static. - beam_scores = [] - for token, prompt_token in zip(beam, prompt): - if prompt_token != 0: - beam_scores.append(0) - elif token == PAD_ID: - beam_scores.append(0) - else: - beam_scores.append(log_probs[token]) - beam_expected_scores.append(sum(beam_scores)) - expected_scores.append(beam_expected_scores) - np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5) - - def test_beam_search_min_log_prob(self): - beam_size = 2 - - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1. - logits = np.repeat( - np.expand_dims( - np.array( - [[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]], - dtype=np.float32, - ), - axis=1, - ), - [beam_size], - axis=1, - ) - logits = decoding.flatten_beam_dim(logits) - return logits, {} - - # No prefix is passed. - inputs = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32) - beam_search_sequences, decoding_scores = decoding.beam_search( - inputs, - {}, - token_to_logits, - EOS_ID, - num_decodes=beam_size, - min_log_prob=-0.05, - ) - - # Prefixes are forced depending on inputs. - # Beam search sequences and corresponding scores are in reverse order. - self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) - expected = np.array( - [[[3, 0, 0, 0, 0], [2, 0, 0, 0, 0]], [[2, 0, 0, 0, 0], [3, 0, 0, 0, 0]]] - ) - np.testing.assert_array_equal(expected, beam_search_sequences) - - def test_beam_search_max_decode_step(self): - beam_size = 2 - - def token_to_logits(decoding_state: decoding.DecodingState): - del decoding_state - # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1. - logits = np.repeat( - np.expand_dims( - np.array( - [[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]], - dtype=np.float32, - ), - axis=1, - ), - [beam_size], - axis=1, - ) - logits = decoding.flatten_beam_dim(logits) - return logits, {} - - # No prefix is passed. - inputs = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32) - beam_search_sequences, decoding_scores = decoding.beam_search( - inputs, - {}, - token_to_logits, - EOS_ID, - num_decodes=beam_size, - max_decode_step=2, - ) - - # Prefixes are forced depending on inputs. - # Beam search sequences and corresponding scores are in reverse order. - self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) - expected = np.array( - [[[3, 2, 0, 0, 0], [2, 2, 0, 0, 0]], [[2, 3, 0, 0, 0], [3, 3, 0, 0, 0]]] - ) - np.testing.assert_array_equal(expected, beam_search_sequences) - - def test_beam_search_force_decode_prefix_with_initial_index_max_decode_step( - self, - ): - beam_size = 2 - - record_decoding_states = [] - - def token_to_logits(decoding_state: decoding.DecodingState): - # Record the decoding_state coming in. - # pdb.set_trace() - record_decoding_states.append(decoding_state) - - # Use id 2 then 3 for batch element 0 and id 3, 2 then EOS for element 1. - logits = np.repeat( - np.expand_dims( - np.array( - [ - [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], - [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], - ], - dtype=np.float32, - ), - axis=1, - ), - [beam_size], - axis=1, - ) - - logits = decoding.flatten_beam_dim(logits) - # Return the cache as-is. - return logits, decoding_state.cache - - # batch element 0 has length 1 and element 1 has length 2. - inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32) - batch_size = inputs.shape[0] - initial_index = np.array([1, 2], dtype=np.int32) - REST_OF_THE_SHAPE = 1024 # dummy pylint: disable=invalid-name - dummy_cache = { - 'cached_bias': np.ones((1, REST_OF_THE_SHAPE), dtype=np.float32), - 'decoder/layers_0/self_attention/cached_key': np.ones( - (batch_size, REST_OF_THE_SHAPE), dtype=np.float32 - ), - 'decoder/layers_0/self_attention/cache_index': np.ones( - (batch_size,), dtype=np.float32 - ), - } - - # Since we are capturing the cache, etc. - with jax.disable_jit(): - beam_search_sequences, decoding_scores = decoding.beam_search( - inputs, - dummy_cache, - token_to_logits, - EOS_ID, - num_decodes=beam_size, - alpha=0, - initial_index=initial_index, - max_decode_step=2, - ) - - # Prefixes are forced depending on inputs. - # Beam search sequences and corresponding scores are in reverse order. - self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) - # batch element 0 failed to find any finished sequence as EOS has logits - # -1e10 << NEG_INF, it extends just two steps. batch element 1 found two - # finished sequences along the way. - expected = np.array( - [[[7, 3, 2, 0, 0], [7, 2, 2, 0, 0]], [[4, 5, 3, 1, 0], [4, 5, 1, 0, 0]]] - ) - np.testing.assert_array_equal(expected, beam_search_sequences) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/disable_gc_during_import.py b/t5x-main/t5x/disable_gc_during_import.py deleted file mode 100644 index 470c6f103730bce9be02bc9d20221f21d69193d6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/disable_gc_during_import.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Disables gc during each top-level import. - -Only takes effect when environment variable -EXPERIMENTAL_DISABLE_GC_DURING_IMPORT -is true. - -Some libraries like SeqIO have lots of side-effects during import time. -In some cases, disabling garbage collection for each top-level import can save -minutes of startup time. - -This should be _relatively_ safe, because we don't expect that it's often that -1. There's sufficient memory pressure during an import to cause an OOM, and -2. That memory pressure would have been sufficiently alleviated by garbage - collection. -""" - -import builtins -import contextlib -import gc -import os - - -@contextlib.contextmanager -def disabled_gc(): - """When used as context manager, prevents garbage collection in scope.""" - if not gc.isenabled(): - # GC is already disabled; don't make any changes. - yield - return - - gc.disable() - try: - yield - finally: - # We know that the original state was enabled because - # we didn't return above. - gc.enable() - - -_original_importlib_import = builtins.__import__ - - -def gc_disabled_import(*args, **kwargs): - with disabled_gc(): - return _original_importlib_import(*args, **kwargs) - - -def try_disable_gc_during_import(): - if os.environ.get('EXPERIMENTAL_DISABLE_GC_DURING_IMPORT'): - builtins.__import__ = gc_disabled_import diff --git a/t5x-main/t5x/disable_gc_during_import_test.py b/t5x-main/t5x/disable_gc_during_import_test.py deleted file mode 100644 index 48ebacd3904435309de055b0b618530fd8469666..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/disable_gc_during_import_test.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for disable_gc_during_import.""" - -# pylint: disable=g-import-not-at-top,unused-import - -import builtins -import gc -import importlib -import os -import sys -from absl.testing import absltest -from absl.testing import parameterized -from t5x import disable_gc_during_import - -_ORIGINAL_BUILTIN_IMPORT_FN = builtins.__import__ - - -def assert_gc_disabled_during_import(): - # Side effect of importing module is asserting gc is disabled. - if sys.modules.get("t5x.assert_gc_disabled_during_import_test_util"): - sys.modules.pop("t5x.assert_gc_disabled_during_import_test_util", None) - - import t5x.assert_gc_disabled_during_import_test_util - - -class DisableGcDuringImportTest(parameterized.TestCase): - - def setUp(self): - super(DisableGcDuringImportTest, self).setUp() - builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT_FN - os.environ["EXPERIMENTAL_DISABLE_GC_DURING_IMPORT"] = "true" - - def tearDown(self): - super(DisableGcDuringImportTest, self).tearDown() - builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT_FN - os.environ.pop("EXPERIMENTAL_DISABLE_GC_DURING_IMPORT") - - def test_gc_enabled_after_one_import_import_builtin(self): - disable_gc_during_import.try_disable_gc_during_import() - - self.assertTrue(gc.isenabled()) - # Some arbitrary import; not particularly important. - import enum - - assert_gc_disabled_during_import() - - self.assertTrue(gc.isenabled()) - - def test_gc_enabled_after_two_imports_import_builtin(self): - disable_gc_during_import.try_disable_gc_during_import() - # from t5x import disable_gc_during_import - - self.assertTrue(gc.isenabled()) - # Some arbitrary imports; not particularly important which ones. - import contextlib - import enum - - assert_gc_disabled_during_import() - - self.assertTrue(gc.isenabled()) - - def test_test_utils_appropriately_detect_when_gc_enabled(self): - with self.assertRaisesRegex(ValueError, "Expected gc to be disabled"): - assert_gc_disabled_during_import() - - -if __name__ == "__main__": - absltest.main() diff --git a/t5x-main/t5x/eval.py b/t5x-main/t5x/eval.py deleted file mode 100644 index aa9dd454ad9dfd0777b734d45eac175f328249bc..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/eval.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint:disable=line-too-long -# pyformat: disable -r"""Runs training- and inference-evaluation on a T5X-compatible model. - -""" - -# pyformat: enable -# pylint:enable=line-too-long -import functools -import os -import re -from typing import Callable, Collection, Mapping, Optional, Sequence, Set, Tuple, Type - -# pylint:disable=g-import-not-at-top -from absl import logging -from clu import metric_writers -import jax -import seqio -from t5x import checkpoints -from t5x import gin_utils -from t5x import models -from t5x import partitioning -from t5x import train_state as train_state_lib -from t5x import trainer as trainer_lib -from t5x import utils -import tensorflow as tf -from tensorflow.io import gfile -from typing_extensions import Protocol -# pylint:enable=g-import-not-at-top - -# Automatically search for gin files relative to the T5X package. -_DEFAULT_GIN_SEARCH_PATHS = [ - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -] - - -class SummarizeConfigFn(Protocol): - - def __call__( - self, - model_dir: str, - summary_writer: Optional[metric_writers.SummaryWriter], - step: int, - ) -> None: - ... - - -class InferenceEvaluator: - """Runs evaluation of the model against a given SeqIo task.""" - - def __init__( - self, - infer_eval_dataset_cfg: utils.DatasetConfig, - inference_evaluator_cls: utils.EvaluatorConstructor, - model: models.BaseModel, - partitioner: partitioning.BasePartitioner, - log_dir: Optional[str] = None, - verify_matching_vocabs_fn: Optional[ - Callable[[utils.DatasetConfig, models.BaseTransformerModel], None] - ] = utils.verify_matching_vocabs, - ): - """Constructs inference evaluator. - - Args: - infer_eval_dataset_cfg: Specification for the dataset to evaluate with - using the inference metrics (e.g., uses sampled decoding). If None, - inference eval is disabled. - inference_evaluator_cls: seqio.Evaluator class to use for inference - evaluation, potentially with bound configuration args. - model: Model to be evaluated. - partitioner: the partitioner to use. - log_dir: Parent directory to log evaluation results. - verify_matching_vocabs_fn: Function to validate whether the task - vocabulary matches the model vocabulary. Should raise an exception on - error. - """ - if verify_matching_vocabs_fn and isinstance( - model, models.BaseTransformerModel - ): - verify_matching_vocabs_fn(infer_eval_dataset_cfg, model) - - self._model = model - self._partitioner = partitioner - self._infer_eval_dataset_cfg = infer_eval_dataset_cfg - kwargs = {} - if log_dir: - kwargs['log_dir'] = os.path.join(log_dir, 'inference_eval') - else: - # Disable loggers if log dir is not provided. - kwargs['logger_cls'] = () - self._seqio_evaluator = inference_evaluator_cls( - mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name, - feature_converter=model.FEATURE_CONVERTER_CLS(pack=False), - eval_split=infer_eval_dataset_cfg.split, - use_cached=infer_eval_dataset_cfg.use_cached, - seed=infer_eval_dataset_cfg.seed, - sequence_length=infer_eval_dataset_cfg.task_feature_lengths, - use_memory_cache=infer_eval_dataset_cfg.use_memory_cache, - **kwargs, - ) - # Lazily initialized upon the first `evaluate` call. - self._predict_fn = None - self._predict_with_aux_fn = None - self._score_fn = None - - @property - def model_feature_shapes(self) -> Mapping[str, Tuple[int, ...]]: - return self._seqio_evaluator.model_feature_shapes - - @property - def eval_tasks(self) -> Sequence[seqio.Task]: - return self._seqio_evaluator.eval_tasks - - def close(self): - self._seqio_evaluator.close() - - def evaluate( - self, - train_state: train_state_lib.TrainState, - train_state_axes: train_state_lib.TrainState, - ) -> seqio.evaluation.AllMetricsFuture: - """Runs the prediction based inference eval. - - Args: - train_state: Training state to run evaluation of. - train_state_axes: partitioning info for the train state to be used. - - Returns: - A dictionary of training eval metrics. - """ - if not self._predict_fn: - self._predict_fn = utils.get_infer_fn( - infer_step=self._model.predict_batch, - batch_size=self._infer_eval_dataset_cfg.batch_size, - train_state_axes=train_state_axes, - partitioner=self._partitioner, - ) - - self._predict_with_aux_fn = utils.get_infer_fn( - infer_step=self._model.predict_batch_with_aux, - batch_size=self._infer_eval_dataset_cfg.batch_size, - train_state_axes=train_state_axes, - partitioner=self._partitioner, - ) - - self._score_fn = utils.get_infer_fn( - infer_step=self._model.score_batch, - batch_size=self._infer_eval_dataset_cfg.batch_size, - train_state_axes=train_state_axes, - partitioner=self._partitioner, - ) - - all_metrics, _ = self._seqio_evaluator.evaluate( - compute_metrics=jax.process_index() == 0, - step=int(utils.get_local_data(train_state.step)), - predict_fn=functools.partial( - self._predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0) - ), - score_fn=functools.partial(self._score_fn, train_state=train_state), - predict_with_aux_fn=functools.partial( - self._predict_with_aux_fn, - train_state=train_state, - rng=jax.random.PRNGKey(0), - ), - ) - return all_metrics - - -def _sorted_ckpt_paths(ckpt_paths: Collection[str]) -> Sequence[str]: - def _extract_ckpt_step(ckpt_path: str) -> int: - # Steps may be prefixed with "checkpoint_", "model.ckpt-" or nothing. - match = re.search(r'(checkpoint_|model.ckpt-)?(\d+)\/?$', ckpt_path) - if match is None: - raise ValueError(f'Invalid checkpoint path: {ckpt_path}') - assert match is not None - return int(match.group(2)) - - return sorted(ckpt_paths, key=_extract_ckpt_step) - - -def _load_evaluated_ckpt_paths(eval_ckpt_path: str) -> Set[str]: - if not gfile.exists(eval_ckpt_path): - return set() - with gfile.GFile(eval_ckpt_path, 'r') as f: - return set(f.read().split()) - - -def evaluate( - *, - model: models.BaseTransformerModel, - dataset_cfg: utils.DatasetConfig, - restore_checkpoint_cfg: utils.RestoreCheckpointConfig, - partitioner: partitioning.BasePartitioner, - output_dir: str, - inference_evaluator_cls: Optional[ - utils.EvaluatorConstructor - ] = seqio.Evaluator, - training_evaluator_cls: Optional[Type[trainer_lib.Trainer]] = None, - summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config, - train_state_initializer_cls: Type[ - utils.TrainStateInitializer - ] = utils.TrainStateInitializer, - train_eval_get_dataset_fn: utils.GetEvalDatasetCallable = utils.get_training_eval_datasets, - fallback_init_rng: Optional[int] = None, - use_orbax: bool = True, -): - """Evaluation function. - - Args: - model: The model object to use for inference. - dataset_cfg: Specification for the dataset to infer based on. - restore_checkpoint_cfg: Specification for the model parameter checkpoint to - load. - partitioner: Partitioner for the model parameters and data across devices. - output_dir: Path to directory to write temporary files and final results. - inference_evaluator_cls: seqio.Evaluator class to use for inference - evaluation, potentially with bound configuration args. - training_evaluator_cls: an optional Trainer class to use for training - evaluation, potentially with bound configuration args. - summarize_config_fn: A function that takes in the model directory, an - optional SummaryWriter, and the step number, and writes a summary of the - configuration. SummaryWriter will be None in most cases. - train_state_initializer_cls: t5x.utils.TrainStateInitializer class for - initializing partitioned TrainState from checkpoints or scratch. - train_eval_get_dataset_fn: Optional callable use to get the train-eval - datasets based on the DatasetConfig and shard information. If missing, it - defaults to `utils.get_training_eval_datasets`. - fallback_init_rng: A random seed used for parameter initialization during - model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is - set to True. If None, parameter initialization is not allowed during model - loading and having fallback_to_scratch enabled will result in an error. - use_orbax: if True, uses Orbax for checkpointing. Experimental feature. - """ - jax.monitoring.record_event('/jax/t5x/evaluate/beacon') - logging.info('Process ID: %d', jax.process_index()) - if dataset_cfg.module: - utils.import_module(dataset_cfg.module) - batch_size = dataset_cfg.batch_size - - summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0) - - evaluator = InferenceEvaluator( - dataset_cfg, - inference_evaluator_cls, - model, - partitioner, - log_dir=output_dir, - ) - if not evaluator.eval_tasks: - raise ValueError( - f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation, " - "or this mixture/task doesn't have provided split." - ) - - # ---------------------------------------------------------------------------- - # T5X model loading. - # ---------------------------------------------------------------------------- - - # Initialize optimizer from the existing checkpoint. - input_shapes = { - k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items() - } - - train_state_initializer = train_state_initializer_cls( - optimizer_def=None, # Do not load optimizer state. - init_fn=model.get_initial_variables, - input_shapes=input_shapes, - partitioner=partitioner, - ) - train_state_axes = train_state_initializer.train_state_axes - # Log the variable shapes information and write to a file. - log_file = os.path.join(output_dir, 'model-info.txt') - utils.log_model_info( - log_file, train_state_initializer.global_train_state_shape, partitioner - ) - - if training_evaluator_cls: - data_layout = partitioner.get_data_layout(dataset_cfg.batch_size) - train_eval_datasets = train_eval_get_dataset_fn( # pytype:disable=missing-parameter - dataset_cfg, - data_layout.shard_id, - data_layout.num_shards, - feature_converter_cls=model.FEATURE_CONVERTER_CLS, - ) - - train_evaluator = training_evaluator_cls( # pytype:disable=wrong-arg-types - model=model, - train_state=None, # Will replace later. - partitioner=partitioner, - train_state_axes=train_state_axes, - eval_names=train_eval_datasets.keys(), - summary_dir=output_dir, - rng=jax.random.PRNGKey(0), # unused - learning_rate_fn=None, # unused - num_microbatches=None, # unused - ) - - def _maybe_run_train_eval(train_state: train_state_lib.TrainState): - if training_evaluator_cls: - train_evaluator.train_state = train_state - train_evaluator.eval({ - task: ( - ds.as_numpy_iterator() if isinstance(ds, tf.data.Dataset) else ds - ) - for task, ds in train_eval_datasets.items() - }) - - # Disable strictness since we are dropping the optimizer state. - restore_checkpoint_cfg.strict = False - - # Skip checkpoints that have already been evaluated. - eval_ckpt_path = os.path.join( - output_dir, f'eval.{dataset_cfg.mixture_or_task_name}.ckpt' - ) - if restore_checkpoint_cfg.mode == 'all' and gfile.exists(eval_ckpt_path): - logging.info('Found evaluation checkpoint: %s', eval_ckpt_path) - - ckpt_dirs = ( - [restore_checkpoint_cfg.path] - if isinstance(restore_checkpoint_cfg.path, str) - else restore_checkpoint_cfg.path - ) - ckpt_paths = set() - for ckpt_dir in ckpt_dirs: - if not gfile.isdir(ckpt_dir): - raise ValueError( - f"Checkpoint path '{ckpt_dir}' must be a valid directory when " - "using restore mode 'all'." - ) - ckpt_paths.update( - checkpoints.get_checkpoint_dir(ckpt_dir, step) - for step in checkpoints.all_steps(ckpt_dir) - ) - - evaluated_ckpt_paths = _load_evaluated_ckpt_paths(eval_ckpt_path) - - logging.info( - 'Skipping evaluated checkpoints:\n %s', - '\n '.join(_sorted_ckpt_paths(ckpt_paths & evaluated_ckpt_paths)), - ) - ckpt_paths = _sorted_ckpt_paths(ckpt_paths - evaluated_ckpt_paths) - restore_cfg = restore_checkpoint_cfg - restore_cfg.mode = 'specific' - - else: - restore_cfg, ckpt_paths = utils.get_first_valid_restore_config_and_paths( - [restore_checkpoint_cfg] - ) - - if fallback_init_rng is not None: - fallback_init_rng = jax.random.PRNGKey(fallback_init_rng) - - for ckpt_path in ckpt_paths: - train_state, _ = utils.create_checkpoint_manager_and_restore( - train_state_initializer, - partitioner, - restore_cfg, - ckpt_path, - fallback_init_rng, - use_orbax=use_orbax, - ) - if train_state is None: - raise ValueError('Failed to restore checkpoint.') - - # ---------------------------------------------------------------------------- - # Main evaluation loop - # ---------------------------------------------------------------------------- - - # Run final evaluation (with decoding) on the full eval dataset. - host_step = int(utils.get_local_data(train_state.step)) - _maybe_run_train_eval(train_state) - all_metrics = evaluator.evaluate(train_state, train_state_axes) - all_metrics.result() # Ensure metrics are finished being computed. - # Wait until computations are done before continuing. - utils.sync_global_devices(f'step_{host_step}:complete') - if jax.process_index() == 0: - # Read/write/replace rather than append to avoid filesystem issue. - evaluated_ckpt_paths = _load_evaluated_ckpt_paths(eval_ckpt_path) - evaluated_ckpt_paths.add(ckpt_path) - with gfile.GFile(eval_ckpt_path, 'w') as f: - f.write('\n'.join(_sorted_ckpt_paths(evaluated_ckpt_paths))) - - logging.info('Finished.') - - -if __name__ == '__main__': - # pylint:disable=g-import-not-at-top - from absl import app - from absl import flags - import fiddle as fdl - import gin - from t5x import config_utils - - FLAGS = flags.FLAGS - - flags.DEFINE_multi_string( - 'gin_file', - default=None, - help=( - 'Path to gin configuration file. Multiple paths may be passed and ' - 'will be imported in the given order, with later configurations ' - 'overriding earlier ones.' - ), - ) - - flags.DEFINE_multi_string( - 'gin_bindings', default=[], help='Individual gin bindings.' - ) - - flags.DEFINE_list( - 'gin_search_paths', - default=['.'], - help=( - 'Comma-separated list of gin config path prefixes to be prepended ' - 'to suffixes given via `--gin_file`. If a file appears in. Only the ' - 'first prefix that produces a valid path for each suffix will be ' - 'used.' - ), - ) - - flags.DEFINE_string( - 'tfds_data_dir', - None, - 'If set, this directory will be used to store datasets prepared by ' - 'TensorFlow Datasets that are not available in the public TFDS GCS ' - 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' - 'all `Task`s.', - ) - - - def main(argv: Sequence[str]): - """Wrapper for pdb post mortems.""" - _main(argv) - - def _main(argv: Sequence[str]): - """True main function.""" - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - if FLAGS.tfds_data_dir: - seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) - - if config_utils.using_fdl(): - config = config_utils.config_with_fiddle(evaluate) - evaluate_using_fiddle = fdl.build(config) - evaluate_using_fiddle() - else: - # Create gin-configurable version of `eval`. - evaluate_using_gin = gin.configurable(evaluate) - - gin_utils.parse_gin_flags( - # User-provided gin paths take precedence if relative paths conflict. - FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, - FLAGS.gin_file, - FLAGS.gin_bindings, - ) - evaluate_using_gin() - - config_utils.run(main) diff --git a/t5x-main/t5x/examples/__init__.py b/t5x-main/t5x/examples/__init__.py deleted file mode 100644 index 548e50465de0fcf5c81a4b08186d8164f705908d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This empty file is needed to be recognized as a package by the setuptools.""" diff --git a/t5x-main/t5x/examples/decoder_only/examples/base_wmt_from_scratch.gin b/t5x-main/t5x/examples/decoder_only/examples/base_wmt_from_scratch.gin deleted file mode 100644 index 9d6adb0cc87ea1609c04f0872a4d3b4cce9bbf21..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/examples/base_wmt_from_scratch.gin +++ /dev/null @@ -1,64 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils - -include "t5x/examples/decoder_only/models/base.gin" -include "t5x/configs/runs/pretrain.gin" - -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} -TRAIN_STEPS = 50000 -DROPOUT_RATE = 0.0 - -train/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - seed = 0 - -train_eval/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - seed = 0 - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # compute max - split = "validation" - seed = 0 - batch_size = 128 - shuffle = False - use_cached = False - -train_script.train: - eval_period = 500 - eval_steps = 20 - random_seed = 0 - use_hardware_rng = True - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - inference_evaluator_cls = @seqio.Evaluator - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the infer_eval dataset. - use_memory_cache = True - -utils.SaveCheckpointConfig: - period = 5000 # checkpoint frequency - -# `num_decodes` is equivalent to a beam size in a beam search decoding. -models.DecoderOnlyModel.predict_batch_with_aux.num_decodes = 8 -models.DecoderOnlyModel.inputs_bidirectional_attention = True - -partitioning.PjitPartitioner.num_partitions = 2 - -utils.create_learning_rate_scheduler: - factors = 'constant * rsqrt_decay' - base_learning_rate = 1.0 - warmup_steps = 10000 diff --git a/t5x-main/t5x/examples/decoder_only/layers.py b/t5x-main/t5x/examples/decoder_only/layers.py deleted file mode 100644 index 9980fec0cf1052d514a53b182388aa67f90aed33..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/layers.py +++ /dev/null @@ -1,1180 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dense attention classes and mask/weighting functions.""" - -# pylint: disable=attribute-defined-outside-init,g-bare-generic - -import dataclasses -import functools -import operator -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union - -from flax import linen as nn -import flax.core.variables as variables -from flax.linen import partitioning as nn_partitioning -from flax.training import common_utils -import jax -from jax import lax -from jax import random -import jax.numpy as jnp -import numpy as np - - -# from flax.linen.partitioning import param_with_axes, with_sharding_constraint -param_with_axes = nn_partitioning.param_with_axes -with_sharding_constraint = nn_partitioning.with_sharding_constraint - - -# Type annotations -Array = jnp.ndarray -DType = jnp.dtype -PRNGKey = jnp.ndarray -Shape = Sequence[int] -Activation = Callable[..., Array] -# Parameter initializers. -Initializer = Callable[[PRNGKey, Shape, DType], Array] - -default_embed_init = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 -) - - -def dot_product_attention( - query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0.0, - deterministic: bool = False, - dtype: DType = jnp.float32, - float32_logits: bool = False, -): - """Computes dot-product attention given query, key, and value. - - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It calculates the attention weights given - query and key and combines the values using the attention weights. - - Args: - query: queries for calculating attention with shape of `[batch, q_length, - num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of `[batch, kv_length, - num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of `[batch, kv_length, - num_heads, v_depth_per_head]`. - bias: bias for the attention weights. This should be broadcastable to the - shape `[batch, num_heads, q_length, kv_length]` This can be used for - incorporating causal masks, padding masks, proximity bias, etc. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - dtype: the dtype of the computation (default: float32) - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - - Returns: - Output of shape `[batch, length, num_heads, v_depth_per_head]`. - """ - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert ( - query.shape[:-3] == key.shape[:-3] == value.shape[:-3] - ), 'q, k, v batch dims must match.' - assert ( - query.shape[-2] == key.shape[-2] == value.shape[-2] - ), 'q, k, v num_heads must match.' - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' - - # Casting logits and softmax computation for float32 for model stability. - if float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) - - # `attn_weights`: [batch, num_heads, q_length, kv_length] - attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) - - # Apply attention bias: masking, dropout, proximity bias, etc. - if bias is not None: - attn_weights = attn_weights + bias.astype(attn_weights.dtype) - - # Normalize the attention weights across `kv_length` dimension. - attn_weights = jax.nn.softmax(attn_weights).astype(dtype) - - # Apply attention dropout. - if not deterministic and dropout_rate > 0.0: - keep_prob = 1.0 - dropout_rate - # T5 broadcasts along the "length" dim, but unclear which one that - # corresponds to in positional dimensions here, assuming query dim. - dropout_shape = list(attn_weights.shape) - dropout_shape[-2] = 1 - keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) - keep = jnp.broadcast_to(keep, attn_weights.shape) - multiplier = keep.astype(attn_weights.dtype) / jnp.asarray( - keep_prob, dtype=dtype - ) - attn_weights = attn_weights * multiplier - - # Take the linear combination of `value`. - return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) - - -class MultiHeadDotProductAttention(nn.Module): - """Multi-head dot-product attention. - - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - head_dim: dimension of each head. - dtype: the dtype of the computation. - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - """ - - num_heads: int - head_dim: int - dtype: DType = jnp.float32 - dropout_rate: float = 0.0 - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal' - ) - float32_logits: bool = False - - def update_cache_prefill( - self, - key: Array, - value: Array, - cached_key: variables.Variable, - cached_value: variables.Variable, - cache_index: variables.Variable, - prefill_lengths: Array, - ) -> Tuple[Array, Array, Array, Array, Array, Array]: - """Update the autoregressive cache for multiple timesteps at once. - - This is useful for things like a prefix-lm where the encoder section of the - input is visible bidirectionally. The key and value for this section need to - be computed in a single shot, as a step by step approach would result in - causal attention. - - Args: - key: The calculated key used in attention. [batch..., length, num_heads, - features_per_head] - value: The calculated value used in attention. [batch..., length, - num_heads, features_per_head] - cached_key: The cache of previous keys. [batch..., num_heads, - features_per_head, length] - cached_value: The cache of previous values. [batch..., num_heads, - features_per_head, length] - cache_index: The timestep that we are currently calculating the key and - value for. [batch] - prefill_lengths: The number of timesteps we should fill in the cache. - [batch] - - Returns: - The key, value, and the last timestep we just filled in the cache. - We also return the new cache values for now because assigning to a - variable inside of a method doesn't work. These returns will be removed - eventually. - """ - # Make a reference to the data underlaying the variable for ease of - # use. - cache_index.value = prefill_lengths - # Note, the cache index is now a vector of batch size so that each example - # can start just after its prefix, which can be different lengths for - # different examples. - cur_index = cache_index.value - # Move the sequence dimension to the end to match the cache shapes. - key_cached = jnp.moveaxis(key, -3, -1) - value_cached = jnp.moveaxis(value, -3, -1) - # Reshape the index so the batch is at the beginning. The default - # broadcasting behavior is to add singleton dims to the front, but we need - # them at the end. - batch_first_index = jnp.reshape( - cur_index, (-1,) + tuple(1 for _ in range(cached_key.value.ndim - 1)) - ) - # Calculate a mask that will set any position past the prefix to zero - # when applied to the key. - key_mask = ( - lax.broadcasted_iota( - jnp.int32, cached_key.value.shape, cached_key.value.ndim - 1 - ) - < batch_first_index - ) - value_mask = ( - lax.broadcasted_iota( - jnp.int32, cached_value.value.shape, cached_value.value.ndim - 1 - ) - < batch_first_index - ) - # Set the caches with the calculated key and values but hide anything - # past the prefix. - cached_key_value = key_cached * key_mask - cached_value_value = value_cached * value_mask - # TODO(hwchung): remove the return values once direct assignment to - # variables inside a method is possible. - return ( - key, - value, - cur_index, - cached_key_value, - cached_value_value, - prefill_lengths, - ) - - def update_cache_decode( - self, - key: Array, - value: Array, - cached_key: variables.Variable, - cached_value: variables.Variable, - cache_index: variables.Variable, - ) -> Tuple[Array, Array, Array, Array, Array, Array]: - """Update the next timestep in the autoregressive cache. - - This is used during step by step decoding where each key and value we get - are a single (the next) timestep. - - Args: - key: The calculated key used in attention. [batch..., 1, num_heads, - features_per_head] - value: The calculated value used in attention. [batch..., 1, num_heads, - features_per_head] - cached_key: The cache of previous keys. [batch..., num_heads, - features_per_head, length] - cached_value: The cache of previous values. [batch..., num_heads, - features_per_head, length] - cache_index: The timestep that we are currently calculating the key and - value for. [batch] if we are decoding after doing a prefill or [1] if we - are starting with step-by-step decoding. - - Returns: - The key, value, and the last timestep we just filled in the cache. Note: - this index is the last timestep we just fill, the actual value of the - `cache_index` is already increased to point to the next timestep to fill. - We also return the new cache values for now because assigning to a - variable inside of a method doesn't work. These returns will be removed - eventually. - """ - cache_length = cached_key.value.shape[-1] - # Create a OHE of the current index. NOTE: the index is increased - # below. - # Note: We reshape the index into a column vector so that it will work - # if the index is a scalar or a vector with different cache positions - # from different elements in a batch. - cur_index = jnp.reshape(cache_index.value, (-1,)) - one_hot_indices = jax.nn.one_hot(cur_index, cache_length, dtype=key.dtype) - # In order to update the key, value caches with the current key and - # value, we move the length axis to the back, similar to what we did - # for the cached ones above. - # Note these are currently the key and value of a single position, - # since we feed one position at a time. - one_token_key = jnp.moveaxis(key, -3, -1) - one_token_value = jnp.moveaxis(value, -3, -1) - # The one hot indices are now either [1, length] for a scalar index or - # [batch size, length] for examples where there are different lengths - # of prefixes. We need to add dims for num_heads and num_features as - # broadcasting doesn't work for the batched version. - one_hot_indices = jnp.expand_dims( - jnp.expand_dims(one_hot_indices, axis=1), axis=1 - ) - # Update key, value caches with our new 1d spatial slices. - # We implement an efficient scatter into the cache via one-hot - # broadcast and addition. - # Key/Value have seq lengths of 1 while one_hot has a seq_length - # of length. key/value will broadcast their value to each timestep - # and the onehot will mask all but the correct timesteps. - key = cached_key.value + one_token_key * one_hot_indices - value = cached_value.value + one_token_value * one_hot_indices - cached_key_value = key - cached_value_value = value - cache_index_value = cache_index.value + 1 - # Move the keys and values back to their original shapes. - key = jnp.moveaxis(key, -1, -3) - value = jnp.moveaxis(value, -1, -3) - # TODO(hwchung): remove the return values once direct assignment to - # variables inside a method is possible. - return ( - key, - value, - cur_index, - cached_key_value, - cached_value_value, - cache_index_value, - ) - - @nn.compact - def __call__( - self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - decode: bool = False, - deterministic: bool = False, - prefill: bool = False, - prefill_lengths: Optional[Array] = None, - ) -> Array: - """Applies multi-head dot product attention on the input data. - - Projects the inputs into multi-headed query, key, and value vectors, - applies dot-product attention and project the results to an output vector. - - There are two modes: decoding and non-decoding (e.g., training). The mode is - determined by `decode`. - - During decoding mode, this method is called twice, by `init` and - `apply`. In the former, inputs_q: `[batch..., length, qkv_features]` and - inputs_kv: `[batch..., length, qkv_features]`. - - During apply, query, key and value all have the shape: `[batch * beam, 1, - qkv_features]` where the batch dimension is added to include multiple beams. - Note that the batch dimension is different during the `init` and `apply` - calls. This is because the cached variables are directly passed-in during - `apply` method. In other words, the cache variables such as `cached_key` are - initialized with `batch` dim, expanded by tiling in the beam search function - to `batch * beam` dimension, and passed to the `apply` method as part of a - variable dict. - - Args: - inputs_q: input queries of shape `[batch, q_length, embed]`. - inputs_kv: key/values of shape `[batch, kv_length, embed]`. - mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. - bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. - decode: whether to prepare and use an autoregressive cache. - deterministic: whether deterministic or not (to apply dropout) - prefill: whether to run a partial sequence to prefill the cache. - prefill_lengths: an array of shape [batch] denoting the length of each - partial sequence we are filling in the cache. - - Returns: - output of shape `[batch, q_length, embed]`. - """ - projection = functools.partial( - DenseGeneral, - axis=-1, - features=(self.num_heads, self.head_dim), - kernel_axes=('embed', 'joined_kv'), - dtype=self.dtype, - ) - - # NOTE: T5 does not explicitly rescale the attention logits by - # 1/sqrt(depth_kq)! This is folded into the initializers of the - # linear transformations, which is equivalent under Adafactor. - depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - query_init = lambda *args: self.kernel_init(*args) / depth_scaling - - # Project inputs_q to multi-headed q/k/v - # dimensions are then [batch, length, num_heads, head_dim] - query = projection(kernel_init=query_init, name='query')(inputs_q) - key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) - value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) - - query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) - key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) - value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) - - if prefill and decode: - raise ValueError( - 'prefill and decode cannot both be true at the same' - 'time. If you are using a prefix LM with bidirectional ' - 'attention on the inputs, please make a call with ' - 'prefill=True that includes an attention mask that ' - 'covers your inputs first and then make your decoding ' - 'calls.' - ) - if prefill or decode: - # Detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable('cache', 'cached_key') - # The key and value have dimension - # [batch..., length, num_heads, features_per_head], but we cache them as - # [batch..., num_heads, features_per_head, length] as a TPU fusion - # optimization. This also enable the "scatter via one-hot broadcast" - # trick, which means we do a one-hot broadcast instead of a scatter/gather - # operations, which gives a 3-4x speedup in practice. - swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) - cached_key = self.variable( - 'cache', 'cached_key', jnp.zeros, swap_dims(key.shape), key.dtype - ) - cached_value = self.variable( - 'cache', - 'cached_value', - jnp.zeros, - swap_dims(value.shape), - value.dtype, - ) - cache_index = self.variable( - 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) - ) - if is_initialized: - # Here we are in "apply()". - *batch_dims, num_heads, features_per_head, length = ( - cached_key.value.shape - ) - if prefill: - if prefill_lengths is None: - # Figure out how far each element in the batch fills the cache based - # on the mask. We index each element in the batch, the first head - # dim (because this is always set to one), and the first query - # vector. If there is any prefix at all, the first element in the - # prefix would be part of it. - prefill_lengths = jnp.sum(mask[:, 0, 0, :], axis=-1).astype( - cache_index.value.dtype - ) - ( - key, - value, - cur_index, - cached_key_value, - cached_value_value, - cache_index_value, - ) = self.update_cache_prefill( - key, value, cached_key, cached_value, cache_index, prefill_lengths - ) - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - elif decode: - # Check the shape of the cached key against the input query. - expected_shape = tuple(batch_dims) + (1, num_heads, features_per_head) - if expected_shape != query.shape: - raise ValueError( - 'Autoregressive cache shape error, ' - 'expected query shape %s instead got %s.' - % (expected_shape, query.shape) - ) - ( - key, - value, - cur_index, - cached_key_value, - cached_value_value, - cache_index_value, - ) = self.update_cache_decode( - key, value, cached_key, cached_value, cache_index - ) - # Enforcing the Causal mask over previous positions and selecting only - # the bias value for the current index is only needed during decode - # mode where a single example is feed at a time. In prefill mode we - # uses these as provided, that same way it is done in a normal forward - # pass, like when computing logits during training. - - # Causal mask for cached decoder self-attention: our single query - # position should only attend to those key positions that have already - # been generated and cached, not the remaining zero elements. - - # (1, 1, length) represent (head dim, query length, key length) - # query length is 1 because during decoding we deal with one - # index. - # The same mask is applied to all batch elements and heads. - # - # Add trailing dims to the current index so it can either - # broadcast over the batch dim or it can just be batch size. - mask = combine_masks( - mask, - jnp.broadcast_to( - jnp.arange(length), tuple(batch_dims) + (1, 1, length) - ) - <= jnp.reshape(cur_index, (-1, 1, 1, 1)), - ) - # Grab the correct relative attention bias during decoding. This is - # only required during single step decoding. - if bias is not None: - # The bias is a full attention matrix, but during decoding we only - # have to take a slice of it. - # This is equivalent to `bias[..., cur_index:cur_index+1, :]`. If - # we are doing prefix decoding where `cur_index` is a vector the - # result will be `[batch, heads, 1, :]`. If `cur_index` is a scalar - # like in encdec decoding, the result will be `[1, heads, 1, :]`. - # We use a one-hot einsum rather than a slice to avoid introducing a - # Gather op that is currently lowered poorly by SPMD passes, adding - # expensive all-reduce and all-gather operations. - - bias = jnp.einsum( - 'bq, bhqk->bhk', - common_utils.onehot(cur_index, num_classes=length), - bias, - ) - bias = jnp.expand_dims(bias, 2) - - # Currently, updating a variable inside of a method is not handled - # in flax, so we return the actual values and assign them in the main - # compacted call for now. - # TODO(brianlester,levskaya): Move variable assignment inside of the - # cache update functions once variable references are tracked across - # transform boundaries. - cache_index.value = cache_index_value - cached_key.value = cached_key_value - cached_value.value = cached_value_value - - # Convert the boolean attention mask to an attention bias. - if mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - mask > 0, - jnp.full(mask.shape, 0.0).astype(self.dtype), - jnp.full(mask.shape, -1e10).astype(self.dtype), - ) - else: - attention_bias = None - - # Add provided bias term (e.g. relative position embedding). - if bias is not None: - attention_bias = combine_biases(attention_bias, bias) - - dropout_rng = None - if not deterministic and self.dropout_rate > 0.0: - dropout_rng = self.make_rng('dropout') - - # Apply attention. - x = dot_product_attention( - query, - key, - value, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - dtype=self.dtype, - float32_logits=self.float32_logits, - ) - - # Back to the original inputs dimensions. - out = DenseGeneral( - features=inputs_q.shape[-1], # output dim is set to the input dim. - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('joined_kv', 'embed'), - dtype=self.dtype, - name='out', - )(x) - return out - - -def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) - - -def _canonicalize_tuple(x): - if isinstance(x, Iterable): - return tuple(x) - else: - return (x,) - - -# ------------------------------------------------------------------------------ -# DenseGeneral for attention layers. -# ------------------------------------------------------------------------------ -class DenseGeneral(nn.Module): - """A linear transformation (without bias) with flexible axes. - - Attributes: - features: tuple with numbers of output features. - axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - """ - - features: Union[Iterable[int], int] - axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'truncated_normal' - ) - kernel_axes: Tuple[str, ...] = () - - @nn.compact - def __call__(self, inputs: Array) -> Array: - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - - Returns: - The transformed input. - """ - features = _canonicalize_tuple(self.features) - axis = _canonicalize_tuple(self.axis) - - inputs = jnp.asarray(inputs, self.dtype) - axis = _normalize_axes(axis, inputs.ndim) - - kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features - kernel_param_shape = ( - np.prod([inputs.shape[ax] for ax in axis]), - np.prod(features), - ) - kernel = param_with_axes( - 'kernel', - self.kernel_init, - kernel_param_shape, - jnp.float32, - axes=self.kernel_axes, - ) - kernel = jnp.asarray(kernel, self.dtype) - kernel = jnp.reshape(kernel, kernel_shape) - - contract_ind = tuple(range(0, len(axis))) - return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) - - -def _convert_to_activation_function( - fn_or_string: Union[str, Callable] -) -> Callable: - """Convert a string to an activation function.""" - if fn_or_string == 'linear': - return lambda x: x - elif isinstance(fn_or_string, str): - return getattr(nn, fn_or_string) - elif callable(fn_or_string): - return fn_or_string - else: - raise ValueError( - "don't know how to convert %s to an activation function" - % (fn_or_string,) - ) - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. - - Attributes: - intermediate_dim: Shared dimension of hidden layers. - activations: Type of activations for each layer. Each element is either - 'linear', a string function name in flax.linen, or a function. - kernel_init: Kernel function, passed to the dense layers. - deterministic: Whether the dropout layers should be deterministic. - intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. - """ - - intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable]] = ('relu',) - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'truncated_normal' - ) - intermediate_dropout_rate: float = 0.1 - dtype: Any = jnp.float32 - - @nn.compact - def __call__(self, inputs, decode: bool = False, deterministic: bool = False): - """Applies Transformer MlpBlock module.""" - # Iterate over specified MLP input activation functions. - # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. - activations = [] - for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' - x = DenseGeneral( - self.intermediate_dim, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), - name=dense_name, - )(inputs) - x = _convert_to_activation_function(act_fn)(x) - activations.append(x) - - # Take elementwise product of above intermediate activations. - x = functools.reduce(operator.mul, activations) - # Apply dropout and final dense output projection. - x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) # Broadcast along length. - x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) - output = DenseGeneral( - inputs.shape[-1], - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - name='wo', - )(x) - return output - - -class Embed(nn.Module): - """A parameterized function from integers [0, n) to d-dimensional vectors. - - Attributes: - num_embeddings: number of embeddings. - features: number of feature dimensions for each embedding. - dtype: the dtype of the embedding vectors (default: float32). - embedding_init: embedding initializer. - one_hot: performs the gather with a one-hot contraction rather than a true - gather. This is currently needed for SPMD partitioning. - """ - - num_embeddings: int - features: int - cast_input_dtype: Optional[DType] = None - dtype: DType = jnp.float32 - attend_dtype: Optional[DType] = None - embedding_init: Initializer = default_embed_init - one_hot: bool = False - embedding: Array = dataclasses.field(init=False) - - def setup(self): - self.embedding = param_with_axes( - 'embedding', - self.embedding_init, - (self.num_embeddings, self.features), - jnp.float32, - axes=('vocab', 'embed'), - ) - - def __call__(self, inputs: Array) -> Array: - """Embeds the inputs along the last dimension. - - Args: - inputs: input data, all dimensions are considered batch dimensions. - - Returns: - Output which is embedded input data. The output shape follows the input, - with an additional `features` dimension appended. - """ - if self.cast_input_dtype: - inputs = inputs.astype(self.cast_input_dtype) - if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') - if self.one_hot: - iota = lax.iota(jnp.int32, self.num_embeddings) - one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) - output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) - else: - output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = with_sharding_constraint(output, ('batch', 'length', 'embed')) - return output - - def attend(self, query: Array) -> Array: - """Attend over the embedding using a query array. - - Args: - query: array with last dimension equal the feature depth `features` of the - embedding. - - Returns: - An array with final dim `num_embeddings` corresponding to the batched - inner-product of the array of query vectors against each embedding. - Commonly used for weight-sharing between embeddings and logit transform - in NLP models. - """ - dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype - return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) - - -class RelativePositionBiases(nn.Module): - """Adds T5-style relative positional embeddings to the attention logits. - - Attributes: - num_buckets: Number of buckets to bucket distances between key and query - positions into. - max_distance: Maximum distance before everything is lumped into the last - distance bucket. - num_heads: Number of heads in the attention layer. Each head will get a - different relative position weighting. - dtype: Type of arrays through this module. - embedding_init: initializer for relative embedding table. - """ - - num_buckets: int - max_distance: int - num_heads: int - dtype: Any - embedding_init: Callable[..., Array] = nn.linear.default_embed_init - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """Translate relative position to a bucket number for relative attention. - - The relative position is defined as memory_position - query_position, i.e. - the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are - invalid. - We use smaller buckets for small absolute relative_position and larger - buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative - positions <=-max_distance map to the same bucket. This should allow for - more graceful generalization to longer sequences than the model has been - trained on. - - Args: - relative_position: an int32 array - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 - values in the range [0, num_buckets) - """ - ret = 0 - n = -relative_position - if bidirectional: - num_buckets //= 2 - ret += (n < 0).astype(np.int32) * num_buckets - n = np.abs(n) - else: - n = np.maximum(n, 0) - # now n is in the range [0, inf) - max_exact = num_buckets // 2 - is_small = n < max_exact - val_if_large = max_exact + ( - np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) - / np.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).astype(np.int32) - val_if_large = np.minimum(val_if_large, num_buckets - 1) - ret += np.where(is_small, n, val_if_large) - return ret - - @nn.compact - def __call__(self, qlen, klen, bidirectional=True, decode=False): - """Produce relative position embedding attention biases. - - Args: - qlen: attention query length. - klen: attention key length. - bidirectional: whether to allow positive memory-query relative position - embeddings. - decode: whether to cache relative position bias during autoregressive - decoding. - - Returns: - output: `(1, num_heads, q_len, k_len)` attention bias - """ - # bidirectional embeddings don't make sense when decoding (and break cache). - if decode and bidirectional: - raise ValueError( - 'bidirectional RelativePositionBiases are not supported when ' - '`decode=True`.' - ) - - # We only cache the bias if the model was already initialized, i.e. if this - # module is called with `model.apply` and `decode = True`. We raise an error - # if called with `model.init` and `decode = True`, since this can cache - # incorrect positional embeddings produced by random parameters. - is_initialized = self.has_variable('params', 'rel_embedding') - if decode and not is_initialized: - raise ValueError( - 'decode-mode cannot be enabled during init. use model.apply to ' - 'initialize the decoding cache.' - ) - - # Return pre-computed relative position bias in cache during decode steps. - if decode and self.has_variable('cache', 'cached_bias'): - cached_bias = self.get_variable('cache', 'cached_bias') - expected_bias_shape = (1, self.num_heads, qlen, klen) - if cached_bias.shape != expected_bias_shape: - raise ValueError( - 'The cached relative position attention bias was ' - f'expected to have shape {expected_bias_shape} but ' - f'instead has the shape {cached_bias.shape}.' - ) - return cached_bias - - # TODO(levskaya): should we be computing this w. numpy as a program - # constant? - context_position = np.arange(qlen, dtype=jnp.int32)[:, None] - memory_position = np.arange(klen, dtype=jnp.int32)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - rp_bucket = self._relative_position_bucket( - relative_position, - bidirectional=bidirectional, - num_buckets=self.num_buckets, - max_distance=self.max_distance, - ) - relative_attention_bias = param_with_axes( - 'rel_embedding', - self.embedding_init, - (self.num_heads, self.num_buckets), - jnp.float32, - axes=('heads', 'relpos_buckets'), - ) - - relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) - # Instead of using a slow gather, we create a leading-dimension one-hot - # array from rp_bucket and use it to perform the gather-equivalent via a - # contraction, i.e.: - # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). - # This is equivalent to relative_attention_bias[:, rp_bucket] - bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) - rp_bucket_one_hot = jnp.array( - rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype - ) - # --> shape (qlen, klen, num_heads) - values = lax.dot_general( - relative_attention_bias, - rp_bucket_one_hot, - (((1,), (0,)), ((), ())), # rhs, lhs contracting dims - ) # no batched dims - # Add a singleton batch dimension. - # --> shape (1, num_heads, qlen, klen) - out = values[jnp.newaxis, ...] - - # Store computed relative position bias in cache after first calculation. - if decode: - _ = self.variable('cache', 'cached_bias', lambda: out) - - return out - - -# ------------------------------------------------------------------------------ -# T5 Layernorm - no subtraction of mean or bias. -# ------------------------------------------------------------------------------ -class LayerNorm(nn.Module): - """T5 Layer normalization operating on the last axis of the input data.""" - - epsilon: float = 1e-6 - dtype: Any = jnp.float32 - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - """Applies layer normalization on the input.""" - x = jnp.asarray(x, jnp.float32) - features = x.shape[-1] - mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) - scale = param_with_axes( - 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',) - ) - - scale = jnp.asarray(scale, self.dtype) - return y * scale - - -# ------------------------------------------------------------------------------ -# Mask-making utility functions. -# ------------------------------------------------------------------------------ -def make_attention_mask( - query_input: Array, - key_input: Array, - pairwise_fn: Callable = jnp.multiply, - extra_batch_dims: int = 0, - dtype: DType = jnp.float32, -) -> Array: - """Mask-making helper for attention weights. - - In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the - attention weights will be `[batch, heads, len_q, len_kv]` and this - function will produce `[batch, 1, len_q, len_kv]`. - - Args: - query_input: a batched, flat input of query_length size - key_input: a batched, flat input of key_length size - pairwise_fn: broadcasting elementwise comparison function - extra_batch_dims: number of extra batch dims to add singleton axes for, none - by default - dtype: mask return dtype - - Returns: - A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. - """ - # [batch, len_q, len_kv] - mask = pairwise_fn( - # [batch, len_q] -> [batch, len_q, 1] - jnp.expand_dims(query_input, axis=-1), - # [batch, len_q] -> [batch, 1, len_kv] - jnp.expand_dims(key_input, axis=-2), - ) - - # [batch, 1, len_q, len_kv]. This creates the head dim. - mask = jnp.expand_dims(mask, axis=-3) - mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) - return mask.astype(dtype) - - -def make_causal_mask( - x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32 -) -> Array: - """Make a causal mask for self-attention. - - In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights - will be `[batch, heads, len, len]` and this function will produce a - causal mask of shape `[batch, 1, len, len]`. - - Note that a causal mask does not depend on the values of x; it only depends on - the shape. If x has padding elements, they will not be treated in a special - manner. - - Args: - x: input array of shape `[batch, len]` - extra_batch_dims: number of batch dims to add singleton axes for, none by - default - dtype: mask return dtype - - Returns: - A `[batch, 1, len, len]` shaped causal mask for 1d attention. - """ - idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) - return make_attention_mask( - idxs, - idxs, - jnp.greater_equal, - extra_batch_dims=extra_batch_dims, - dtype=dtype, - ) - - -def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): - """Combine attention masks. - - Args: - *masks: set of attention mask arguments to combine, some can be None. - dtype: final mask dtype - - Returns: - Combined mask, reduced by logical and, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all( - map(lambda x: x.ndim == masks[0].ndim, masks) - ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}' - mask, *other_masks = masks - for other_mask in other_masks: - mask = jnp.logical_and(mask, other_mask) - return mask.astype(dtype) - - -def combine_biases(*masks: Optional[Array]): - """Combine attention biases. - - Args: - *masks: set of attention bias arguments to combine, some can be None. - - Returns: - Combined mask, reduced by summation, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all( - map(lambda x: x.ndim == masks[0].ndim, masks) - ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}' - mask, *other_masks = masks - for other_mask in other_masks: - mask = mask + other_mask - return mask - - -def make_decoder_mask( - decoder_target_tokens: Array, - dtype: DType, - decoder_causal_attention: Optional[Array] = None, - decoder_segment_ids: Optional[Array] = None, -) -> Array: - """Compute the self-attention mask for a decoder. - - Decoder mask is formed by combining a causal mask, a padding mask and an - optional packing mask. If decoder_causal_attention is passed, it makes the - masking non-causal for positions that have value of 1. - - A prefix LM is applied to a dataset which has a notion of "inputs" and - "targets", e.g., a machine translation task. The inputs and targets are - concatenated to form a new target. `decoder_target_tokens` is the concatenated - decoder output tokens. - - The "inputs" portion of the concatenated sequence can attend to other "inputs" - tokens even for those at a later time steps. In order to control this - behavior, `decoder_causal_attention` is necessary. This is a binary mask with - a value of 1 indicating that the position belonged to "inputs" portion of the - original dataset. - - Example: - - Suppose we have a dataset with two examples. - - ds = [{"inputs": [6, 7], "targets": [8]}, - {"inputs": [3, 4], "targets": [5]}] - - After the data preprocessing with packing, the two examples are packed into - one example with the following three fields (some fields are skipped for - simplicity). - - decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] - decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] - decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] - - where each array has [batch, length] shape with batch size being 1. Then, - this function computes the following mask. - - mask = [[[[1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0]]]] - - mask[b, 1, :, :] represents the mask for the example `b` in the batch. - Because mask is for a self-attention layer, the mask's shape is a square of - shape [query length, key length]. - - mask[b, 1, i, j] = 1 means that the query token at position i can attend to - the key token at position j. - - Args: - decoder_target_tokens: decoder output tokens. [batch, length] - dtype: dtype of the output mask. - decoder_causal_attention: a binary mask indicating which position should - only attend to earlier positions in the sequence. Others will attend - bidirectionally. [batch, length] - decoder_segment_ids: decoder segmentation info for packed examples. [batch, - length] - - Returns: - the combined decoder mask. - """ - masks = [] - # The same mask is applied to all attention heads. So the head dimension is 1, - # i.e., the mask will be broadcast along the heads dim. - # [batch, 1, length, length] - causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) - - # Positions with value 1 in `decoder_causal_attneition` can attend - # bidirectionally. - if decoder_causal_attention is not None: - # [batch, 1, length, length] - inputs_mask = make_attention_mask( - decoder_causal_attention, - decoder_causal_attention, - jnp.logical_and, - dtype=dtype, - ) - masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) - else: - masks.append(causal_mask) - - # Padding mask. - masks.append( - make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype - ) - ) - - # Packing mask - if decoder_segment_ids is not None: - masks.append( - make_attention_mask( - decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype - ) - ) - - return combine_masks(*masks, dtype=dtype) # pytype: disable=bad-return-type # jax-ndarray diff --git a/t5x-main/t5x/examples/decoder_only/layers_test.py b/t5x-main/t5x/examples/decoder_only/layers_test.py deleted file mode 100644 index b43d9d547902827685f115ef6017a67facab63c6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/layers_test.py +++ /dev/null @@ -1,847 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for attention classes.""" - -import dataclasses -from typing import Optional -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from flax import linen as nn -from flax.core import freeze -from flax.linen import partitioning as nn_partitioning -import jax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp -import numpy as np -from t5x.examples.decoder_only import layers - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -Array = jnp.ndarray -AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name - - -class SelfAttention(layers.MultiHeadDotProductAttention): - """Self-attention special case of multi-head dot-product attention.""" - - @nn.compact - def __call__( - self, - inputs_q: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - deterministic: bool = False, - ): - return super().__call__( - inputs_q, inputs_q, mask, bias, deterministic=deterministic - ) - - -@dataclasses.dataclass(frozen=True) -class SelfAttentionArgs: - num_heads: int = 1 - batch_size: int = 2 - # qkv_features: int = 3 - head_dim: int = 3 - # out_features: int = 4 - q_len: int = 5 - features: int = 6 - dropout_rate: float = 0.1 - deterministic: bool = False - decode: bool = False - float32_logits: bool = False - - def __post_init__(self): - # If we are doing decoding, the query length should be 1, because are doing - # autoregressive decoding where we feed one position at a time. - assert not self.decode or self.q_len == 1 - - def init_args(self): - return dict( - num_heads=self.num_heads, - head_dim=self.head_dim, - dropout_rate=self.dropout_rate, - float32_logits=self.float32_logits, - ) - - def apply_args(self): - inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) - mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - return { - 'inputs_q': inputs_q, - 'mask': mask, - 'bias': bias, - 'deterministic': self.deterministic, - } - - -class AttentionTest(parameterized.TestCase): - - def test_dot_product_attention_shape(self): - # This test only checks for shape but tries to make sure all code paths are - # reached. - dropout_rng = random.PRNGKey(0) - batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 - - query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) - key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) - value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) - bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) - - args = dict( - query=query, - key=key, - value=value, - bias=bias, - dropout_rng=dropout_rng, - dropout_rate=0.5, - deterministic=False, - ) - - output = layers.dot_product_attention(**args) - self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) - - def test_make_attention_mask_multiply_pairwise_fn(self): - decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) - attention_mask = layers.make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32 - ) - expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) - expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) - self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_attention_mask_equal_pairwise_fn(self): - segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) - attention_mask = layers.make_attention_mask( - segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32 - ) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected0 = jnp.array([ - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1], - ]) - expected1 = jnp.array([ - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1], - ]) - self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_causal_mask_with_padding(self): - x = jnp.array([[7, 0, 0], [8, 5, 0]]) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (2, 1, 3, 3)) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected_y = jnp.array( - [[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]], jnp.float32 - ) - np.testing.assert_allclose(y[0], expected_y) - np.testing.assert_allclose(y[1], expected_y) - - def test_make_causal_mask_extra_batch_dims(self): - x = jnp.ones((3, 3, 5)) - y = layers.make_causal_mask(x, extra_batch_dims=2) - self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) - - def test_make_causal_mask(self): - x = jnp.ones((1, 3)) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (1, 1, 3, 3)) - expected_y = jnp.array( - [[[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]]], jnp.float32 - ) - np.testing.assert_allclose(y, expected_y) - - def test_combine_masks(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), - None, - jnp.array([1, 1, 1, 1], jnp.float32), - jnp.array([1, 1, 1, 0], jnp.float32), - ] - y = layers.combine_masks(*masks) - np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) - - def test_combine_biases(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), - None, - jnp.array([0, 1, 1, 1], jnp.float32), - jnp.array([0, 1, 1, 0], jnp.float32), - ] - y = layers.combine_biases(*masks) - np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) - - def test_make_decoder_mask_lm_unpacked(self): - decoder_target_tokens = jnp.array([6, 7, 3, 0]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32 - ) - expected_mask = jnp.array( - [[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]]] - ) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_lm_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask = jnp.array([[[ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask = jnp.array( - [[[ - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ]]], - dtype=jnp.float32, - ) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_packed(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask = jnp.array([[[ - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask0 = jnp.array( - [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]] - ) - expected_mask1 = jnp.array( - [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] - ) - self.assertEqual(mask.shape, (2, 1, 4, 4)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - np.testing.assert_array_equal(mask[1, 0], expected_mask1) - - def test_make_decoder_mask_composite_causal_attention(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask0 = jnp.array([ - [1, 1, 0, 0, 1, 1, 0], - [1, 1, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ]) - - self.assertEqual(mask.shape, (1, 1, 7, 7)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - def test_make_decoder_mask_composite_causal_attention_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask0 = jnp.array([ - [1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1], - ]) - - self.assertEqual(mask.shape, (1, 1, 9, 9)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - @parameterized.parameters({'f': 20}, {'f': 22}) - def test_multihead_dot_product_attention(self, f): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - np.random.seed(0) - inputs_q = np.random.randn(b, q, f) - inputs_kv = np.random.randn(b, k, f) - - # Projection: [b, q, f] -> [b, q, h, d] - # So the kernels have to be [f, h, d] - query_kernel = np.random.randn(f, h, d) - key_kernel = np.random.randn(f, h, d) - value_kernel = np.random.randn(f, h, d) - # `out` calculation: [b, q, h, d] -> [b, q, f] - # So kernel has to be [h, d, f] - out_kernel = np.random.randn(h, d, f) - - params = { - 'query': {'kernel': query_kernel.reshape(f, -1)}, - 'key': {'kernel': key_kernel.reshape(f, -1)}, - 'value': {'kernel': value_kernel.reshape(f, -1)}, - 'out': {'kernel': out_kernel.reshape(-1, f)}, - } - y = layers.MultiHeadDotProductAttention(**args).apply( - {'params': freeze(params)}, inputs_q, inputs_kv - ) - - query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) - key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) - value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = nn.softmax(logits, axis=-1) - combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) - y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) - np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) - - def test_multihead_dot_product_attention_caching(self): - # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim - b, h, d, k = 2, 3, 4, 5 - f = h * d - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - cache = { - 'cached_key': np.zeros((b, h, d, k)), - 'cached_value': np.zeros((b, h, d, k)), - 'cache_index': np.array(0), - } - inputs_q = np.random.randn(b, 1, f) - inputs_kv = np.random.randn(b, 1, f) - - # Mock dense general such that q, k, v projections are replaced by simple - # reshaping. - def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument - return x.reshape(b, -1, h, d) - - with mock.patch.object( - layers.DenseGeneral, '__call__', new=mock_dense_general - ): - _, mutated = layers.MultiHeadDotProductAttention(**args).apply( - {'cache': freeze(cache)}, - inputs_q, - inputs_kv, - decode=True, - mutable=['cache'], - ) - updated_cache = mutated['cache'] - - # Perform the same mocked projection to generate the expected cache. - # (key|value): [b, 1, h, d] - key = mock_dense_general(None, inputs_kv) - value = mock_dense_general(None, inputs_kv) - - # cached_(key|value): [b, h, d, k] - cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] - cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] - cache['cache_index'] = np.array(1) - for name, array in cache.items(): - np.testing.assert_allclose(array, updated_cache[name]) - - def test_dot_product_attention(self): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - np.random.seed(0) - query = np.random.randn(b, q, h, d) - key = np.random.randn(b, k, h, d) - value = np.random.randn(b, k, h, d) - bias = np.random.randn(b, h, q, k) - attn_out = layers.dot_product_attention(query, key, value, bias=bias) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = jax.nn.softmax(logits + bias, axis=-1) - expected = np.einsum('bhqk,bkhd->bqhd', weights, value) - np.testing.assert_allclose(attn_out, expected, atol=1e-6) - - def test_multihead_dot_product_attention_prefill_caching(self): - # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim - b, h, d, k = 2, 3, 4, 5 - f = h * d - prefill_lengths = np.array([3, 1]) - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - cache = { - 'cached_key': np.zeros((b, h, d, k)), - 'cached_value': np.zeros((b, h, d, k)), - 'cache_index': np.array([0, 0]), - } - inputs_q = np.random.randn(b, k, f) - inputs_kv = np.random.randn(b, k, f) - - # Mock dense general such that q, k, v projections are replaced by simple - # reshaping. - def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument - return x.reshape(b, -1, h, d) - - with mock.patch.object( - layers.DenseGeneral, '__call__', new=mock_dense_general - ): - _, mutated = layers.MultiHeadDotProductAttention(**args).apply( - {'cache': freeze(cache)}, - inputs_q, - inputs_kv, - decode=False, - prefill=True, - prefill_lengths=prefill_lengths, - mutable=['cache'], - ) - updated_cache = mutated['cache'] - - # Perform the same mocked projection to generate the expected cache. - # (key|value): [b, 1, h, d] - key = mock_dense_general(None, inputs_kv) - value = mock_dense_general(None, inputs_kv) - - # cached_(key|value): [b, h, d, k] - # Update the our gold cache with the key and values that are part of the - # prefix that we are prefilling the cache with. Explicit loops here avoid a - # confusing transpose. - for b, prefill_length in enumerate(prefill_lengths): - for i in range(prefill_length): - cache['cached_key'][b, :, :, i] = key[b, i, :, :] - cache['cached_value'][b, :, :, i] = value[b, i, :, :] - cache['cache_index'][b] = prefill_length - for name, array in cache.items(): - np.testing.assert_allclose(array, updated_cache[name]) - - -class EmbeddingTest(parameterized.TestCase): - - def test_embedder_raises_exception_for_incorrect_input_type(self): - """Tests that inputs are integers and that an exception is raised if not.""" - embed = layers.Embed(num_embeddings=10, features=5) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - bad_inputs = inputs.astype(np.float32) - with self.assertRaisesRegex( - ValueError, 'Input type must be an integer or unsigned integer.' - ): - _ = embed.apply(variables, bad_inputs) - - @parameterized.named_parameters( - { - 'testcase_name': 'with_ones', - 'init_fn': jax.nn.initializers.ones, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 5 * 10, - }, - { - 'testcase_name': 'with_zeros', - 'init_fn': jax.nn.initializers.zeros, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 0, - }, - ) - def test_embedding_initializes_correctly( - self, init_fn, num_embeddings, features, matrix_sum - ): - """Tests if the Embed class initializes with the requested initializer.""" - embed = layers.Embed( - num_embeddings=num_embeddings, features=features, embedding_init=init_fn - ) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) - - def test_embedding_matrix_shape(self): - """Tests that the embedding matrix has the right shape.""" - num_embeddings = 10 - features = 5 - embed = layers.Embed(num_embeddings=num_embeddings, features=features) - inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual((num_embeddings, features), embedding_matrix.shape) - - def test_embedding_attend(self): - """Tests that attending with ones returns sum of embedding vectors.""" - features = 5 - embed = layers.Embed(num_embeddings=10, features=features) - inputs = np.array([[1]], dtype=np.int64) - variables = embed.init(jax.random.PRNGKey(0), inputs) - query = np.ones(features, dtype=np.float32) - result = embed.apply(variables, query, method=embed.attend) - expected = np.sum(variables['params']['embedding'], -1) - np.testing.assert_array_almost_equal(result, expected) - - -class DenseTest(parameterized.TestCase): - - def test_dense_general_no_bias(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=4, - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - self.assertEqual(y.shape, (1, 4)) - np.testing.assert_allclose(y, np.full((1, 4), 3.0)) - - def test_dense_general_two_features(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=(2, 2), - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - # We transform the last input dimension to two output dimensions (2, 2). - np.testing.assert_allclose(y, np.full((1, 2, 2), 3.0)) - - def test_dense_general_two_axes(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 2, 2)) - model = layers.DenseGeneral( - features=3, - axis=(-2, 2), # Note: this is the same as (1, 2). - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - # We transform the last two input dimensions (2, 2) to one output dimension. - np.testing.assert_allclose(y, np.full((1, 3), 4.0)) - - def test_mlp_same_out_dim(self): - module = layers.MlpBlock( - intermediate_dim=4, - activations=('relu',), - kernel_init=nn.initializers.xavier_uniform(), - dtype=jnp.float32, - ) - inputs = np.array( - [ - # Batch 1. - [[1, 1], [1, 1], [1, 2]], - # Batch 2. - [[2, 2], [3, 1], [2, 2]], - ], - dtype=np.float32, - ) - params = module.init(random.PRNGKey(0), inputs, deterministic=True) - # self.assertEqual( - # jax.tree.map(lambda a: a.tolist(), params), - # { - # 'params': { - # 'wi': { - # 'kernel': [ - # [ - # -0.8675811290740967, - # 0.08417510986328125, - # 0.022586345672607422, - # -0.9124102592468262, - # ], - # [ - # -0.19464373588562012, - # 0.49809837341308594, - # 0.7808468341827393, - # 0.9267289638519287, - # ], - # ], - # }, - # 'wo': { - # 'kernel': [ - # [0.01154780387878418, 0.1397249698638916], - # [0.974980354309082, 0.5903260707855225], - # [-0.05997943878173828, 0.616570234298706], - # [0.2934272289276123, 0.8181164264678955], - # ], - # }, - # }, - # 'params_axes': { - # 'wi': { - # 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), - # }, - # 'wo': { - # 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), - # }, - # }, - # }, - # ) - result = module.apply(params, inputs, deterministic=True) # pylint: disable=unused-variable - # np.testing.assert_allclose( - # result.tolist(), - # [ - # [ - # [0.5237172245979309, 0.8508185744285583], - # [0.5237172245979309, 0.8508185744285583], - # [1.2344461679458618, 2.3844780921936035], - # ], - # [ - # [1.0474344491958618, 1.7016371488571167], - # [0.6809444427490234, 0.9663378596305847], - # [1.0474344491958618, 1.7016371488571167], - # ], - # ], - # rtol=1e-6, - # ) - - -class RelativePositionBiasesTest(absltest.TestCase): - - def setUp(self): - self.num_heads = 3 - self.query_len = 5 - self.key_len = 7 - self.relative_attention = layers.RelativePositionBiases( - num_buckets=12, - max_distance=10, - num_heads=3, - dtype=jnp.float32, - ) - super(RelativePositionBiasesTest, self).setUp() - - def test_relative_attention_bidirectional_params(self): - """Tests that bidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True - ) - param_shapes = jax.tree.map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, - { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': AxisMetadata( - names=('heads', 'relpos_buckets') - ), - }, - }, - ) - - def test_regression_relative_attention_bidirectional_values(self): - """Tests that bidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True - ) - self.assertEqual( - outputs.shape, (1, self.num_heads, self.query_len, self.key_len) - ) - # self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - # self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - # self.assertAlmostEqual(outputs[0, 1, 4, 6], 0.14510104, places=5) - # self.assertAlmostEqual(outputs[0, 2, 4, 6], -0.36783996, places=5) - - def test_relative_attention_unidirectional_params(self): - """Tests that unidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False - ) - param_shapes = jax.tree.map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, - { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': AxisMetadata( - names=('heads', 'relpos_buckets') - ), - }, - }, - ) - - def test_regression_relative_attention_unidirectional_values(self): - """Tests that unidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False - ) - self.assertEqual( - outputs.shape, (1, self.num_heads, self.query_len, self.key_len) - ) - # self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - # self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - # self.assertAlmostEqual(outputs[0, 1, 4, 6], -0.13101986, places=5) - # self.assertAlmostEqual(outputs[0, 2, 4, 6], 0.39296466, places=5) - - def test_relative_attention_decode_cache_error_with_init(self): - """Tests that relative embedding init fails with decode == True.""" - with self.assertRaisesRegex( - ValueError, - 'decode-mode cannot be enabled during init. use model.apply to ' - 'initialize the decoding cache.', - ): - self.relative_attention.init( - jax.random.PRNGKey(0), - self.query_len, - self.key_len, - bidirectional=False, - decode=True, - ) - - def test_relative_attention_decode_cache_errror_with_bidirectional(self): - """Tests that bidirectional relative embeddings fails when decoding.""" - params = self.relative_attention.init( - jax.random.PRNGKey(0), - self.query_len, - self.key_len, - bidirectional=False, - decode=False, - ) - - with self.assertRaisesRegex( - ValueError, - 'bidirectional RelativePositionBiases are not supported when ' - '`decode=True`.', - ): - self.relative_attention.apply( - params, - self.query_len, - self.key_len, - bidirectional=True, - decode=True, - mutable=['cache'], - ) - - def test_relative_attention_decode_cache(self): - """Tests that relative embeddings are correctly cached when decode=True.""" - - params = self.relative_attention.init( - jax.random.PRNGKey(0), - self.query_len, - self.key_len, - bidirectional=False, - decode=False, - ) - - # during init, cache is not actually initialized. - self.assertNotIn('cache', params) - - outputs, state = self.relative_attention.apply( - params, - self.query_len, - self.key_len, - bidirectional=False, - decode=True, - mutable=['cache'], - ) - - self.assertEqual( - outputs.shape, (1, self.num_heads, self.query_len, self.key_len) - ) - - self.assertIn('cached_bias', state['cache']) - - cached_bias = state['cache']['cached_bias'] - - # self.assertAlmostEqual(cached_bias[0, 0, 0, 0], 0.55764728, places=5) - # self.assertAlmostEqual(cached_bias[0, 1, 2, 1], -0.10935841, places=5) - # self.assertAlmostEqual(cached_bias[0, 1, 4, 6], -0.13101986, places=5) - # self.assertAlmostEqual(cached_bias[0, 2, 4, 6], 0.39296466, places=5) - - np.testing.assert_array_equal(outputs, state['cache']['cached_bias']) - - params_with_cache = { - **params, - **state, - } - - outputs, state = self.relative_attention.apply( - params_with_cache, - self.query_len, - self.key_len, - bidirectional=False, - decode=True, - mutable=['cache'], - ) - - np.testing.assert_array_equal(cached_bias, state['cache']['cached_bias']) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/examples/decoder_only/models/base.gin b/t5x-main/t5x/examples/decoder_only/models/base.gin deleted file mode 100644 index d0bed734241f03e1066b357882d96d72d162ee48..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/models/base.gin +++ /dev/null @@ -1,59 +0,0 @@ -# Decoder-only model (Base) with 134307072 parameters. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import decoding -from t5x import models -from t5x.examples.decoder_only import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.DecoderOnlyModel() -models.DecoderOnlyModel: - module = @network.DecoderWrapper() - vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - decode_fn = @decoding.temperature_sample - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -decoding.temperature_sample: - temperature = 1.0 - topk = 40 - -# ------------------- Network specification ------------------------------------ -network.DecoderWrapper.config = @network.TransformerConfig() -network.TransformerConfig: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_layers = 12 - head_dim = 64 - mlp_dim = 2048 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = True diff --git a/t5x-main/t5x/examples/decoder_only/models/large.gin b/t5x-main/t5x/examples/decoder_only/models/large.gin deleted file mode 100644 index 381e44519a5b716b9727a63875feacc2b2b2a56a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/models/large.gin +++ /dev/null @@ -1,10 +0,0 @@ -include 't5x/examples/decoder_only/models/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -# Parameters obtained from similar config for encoder-decoder model: large.gin -network.TransformerConfig: - emb_dim = 1024 - num_heads = 16 - num_layers = 24 - head_dim = 64 - mlp_dim = 2816 diff --git a/t5x-main/t5x/examples/decoder_only/models/xl.gin b/t5x-main/t5x/examples/decoder_only/models/xl.gin deleted file mode 100644 index c0102ca9bb7432ab6a379c4884690e8e7df904f9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/models/xl.gin +++ /dev/null @@ -1,11 +0,0 @@ -# Decoder-only model (XL) with 764274688 parameters. - -include 't5x/examples/decoder_only/models/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.TransformerConfig: - emb_dim = 2048 - num_heads = 32 - num_layers = 24 - head_dim = 64 - mlp_dim = 5120 diff --git a/t5x-main/t5x/examples/decoder_only/models/xxl.gin b/t5x-main/t5x/examples/decoder_only/models/xxl.gin deleted file mode 100644 index bd4a6b5c541ddd89d7f80297c047bd9784a3cb59..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/models/xxl.gin +++ /dev/null @@ -1,11 +0,0 @@ -# Decoder-only model (XXL) with 4762357760 parameters. - -include 't5x/examples/decoder_only/models/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.TransformerConfig: - emb_dim = 4096 - num_heads = 64 - num_layers = 24 - head_dim = 64 - mlp_dim = 10240 diff --git a/t5x-main/t5x/examples/decoder_only/network.py b/t5x-main/t5x/examples/decoder_only/network.py deleted file mode 100644 index bcddd95d6b1e480858482a5abf37791e65603d03..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/network.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Minimal decoder-only Transformer model.""" - -from typing import Any, Optional, Sequence - -from flax import linen as nn -from flax import struct -import jax.numpy as jnp -from t5x.examples.decoder_only import layers - - -@struct.dataclass -class TransformerConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - - vocab_size: int - # Activation dtypes. - dtype: Any = jnp.float32 - emb_dim: int = 512 - num_heads: int = 8 - num_layers: int = 6 - head_dim: int = 64 - mlp_dim: int = 2048 - # Activation functions are retrieved from Flax. - mlp_activations: Sequence[str] = ('relu',) - dropout_rate: float = 0.1 - # If `True`, the embedding weights are used in the decoder output layer. - logits_via_embedding: bool = False - - -class DecoderLayer(nn.Module): - """Transformer decoder layer.""" - - config: TransformerConfig - relative_embedding: nn.Module - - @nn.compact - def __call__( - self, - inputs: jnp.ndarray, - decoder_mask: Optional[jnp.ndarray] = None, - deterministic: bool = False, - decode: bool = False, - max_decode_length: Optional[int] = None, - prefill: bool = False, - prefill_lengths: Optional[jnp.ndarray] = None, - ): - """Applies decoder block module.""" - cfg = self.config - - # Relative position embedding as attention biases. - l = max_decode_length if decode and max_decode_length else inputs.shape[-2] - - # During decoding, this module will be called with `decode=True` first to - # initialize the decoder cache, including a cached relpos bias. The prefill - # codepath will call this once again with `decode=False`, which is slightly - # wasteful but generally harmless. During subsequent decode steps, this will - # be called with `decode=True` and will reuse the cached bias. This - # significantly improves performance during decoding with many decode steps. - decoder_bias = self.relative_embedding(l, l, False, decode=decode) - - # `inputs` is layer input with a shape [batch, length, emb_dim]. - x = layers.LayerNorm(dtype=cfg.dtype, name='pre_self_attention_layer_norm')( - inputs - ) - - # Self-attention block - x = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - name='self_attention', - )( - x, - x, - decoder_mask, - decoder_bias, - deterministic=deterministic, - decode=decode, - prefill=prefill, - prefill_lengths=prefill_lengths, - ) - x = nn.Dropout( - rate=cfg.dropout_rate, - broadcast_dims=(-2,), - name='post_self_attention_dropout', - )(x, deterministic=deterministic) - x = x + inputs - - # MLP block. - y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) - y = layers.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - name='mlp', - )(y, deterministic=deterministic) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,), name='post_mlp_dropout' - )(y, deterministic=deterministic) - y = y + x - - return y - - -class Decoder(nn.Module): - """A stack of decoder layers.""" - - config: TransformerConfig - - @nn.compact - def __call__( - self, - decoder_input_tokens: jnp.ndarray, - decoder_target_tokens: jnp.ndarray, - decoder_segment_ids: Optional[jnp.ndarray] = None, - decoder_positions: Optional[jnp.ndarray] = None, - decoder_causal_attention: Optional[jnp.ndarray] = None, - *, - enable_dropout: bool = True, - decode: bool = False, - max_decode_length: Optional[int] = None, - prefill: Optional[bool] = None, - prefill_lengths: Optional[jnp.ndarray] = None, - ): - """Applies LanguageModel on the inputs. - - For a decoder-only architecture with the notion of "prefix", e.g., a prefix - LM where the prefix corresponds to the "inputs" of a supervised dataset, we - perform the "prefill" operation to fill the autoregressive cache - corresponding to the prefix region in one go. Then the autoregressive - decoding starts after the prefix. This makes the decoding process more - efficient. In addition, it gives an option to use bidirectional attention in - the prefix region because the cache is filled simultaneously. - - Args: - decoder_input_tokens: input token to the decoder. - decoder_target_tokens: target token to the decoder. - decoder_segment_ids: decoder segmentation info for packed examples. - decoder_positions: decoder subsequence positions for packed examples. - decoder_causal_attention: a binary mask indicating the portion of the - sequence to apply bidirectional attention to instead of causal. As an - example, useful to specify the "inputs" portion of a concatenated - sequence for a prefix LM. - enable_dropout: enables dropout if set to True. - decode: whether to prepare and use an autoregressive cache as opposed to - using teacher-forcing. - max_decode_length: maximum sequence length to be decoded. - prefill: whether to run a partial sequence to prefill the cache. - prefill_lengths: an array of shape [batch] denoting the length of each - partial sequence we are filling in the cache. - - Returns: - logits array. - """ - cfg = self.config - deterministic = not enable_dropout - assert decoder_input_tokens.ndim == 2 # [batch, len] - rel_emb = layers.RelativePositionBiases( - num_buckets=32, - max_distance=128, - num_heads=cfg.num_heads, - dtype=cfg.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform' - ), - name='relpos_bias', - ) - - if decode: - decoder_mask = None - else: - decoder_mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=cfg.dtype, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids, - ) - - embedding = layers.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - attend_dtype=jnp.float32, # for logit training stability - embedding_init=nn.initializers.normal(stddev=1.0), - one_hot=True, - name='token_embedder', - ) - y = embedding(decoder_input_tokens.astype('int32')) - - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,), name='input_dropout' - )(y, deterministic=deterministic) - y = y.astype(cfg.dtype) - - for lyr in range(cfg.num_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = DecoderLayer( - config=cfg, relative_embedding=rel_emb, name=f'layers_{lyr}' - )( - y, - decoder_mask=decoder_mask, - deterministic=deterministic, - decode=decode, - max_decode_length=max_decode_length, - prefill=prefill, - prefill_lengths=prefill_lengths, - ) - - y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,), name='output_dropout' - )(y, deterministic=deterministic) - - # [batch, length, emb_dim] -> [batch, length, vocab_size] - if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for the logit transform. - logits = embedding.attend(y) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - else: - # Use a separate dense layer for the logit transform. - logits = layers.DenseGeneral( - cfg.vocab_size, - dtype=jnp.float32, # Use float32 for stabiliity. - kernel_axes=('embed', 'vocab'), - name='logits_dense', - )(y) - return logits - - -# TODO(hwchung): remove this after figuring out the name scope issue. -class DecoderWrapper(nn.Module): - """Thin wrapper for the outer "decoder/" name scope.""" - - config: TransformerConfig - - def setup(self): - self.decoder = Decoder(self.config, name='decoder') - - def __call__(self, *args, **kwargs): - return self.decoder(*args, **kwargs) diff --git a/t5x-main/t5x/examples/decoder_only/network_test.py b/t5x-main/t5x/examples/decoder_only/network_test.py deleted file mode 100644 index 15b4fd29fb8debfaec9af122015d9043234f4470..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/decoder_only/network_test.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for network.""" - -import os - -from absl import flags -from absl.testing import absltest -from absl.testing import parameterized -import jax -import numpy as np -from t5x import test_utils - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -FLAGS = flags.FLAGS - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/examples/scalable_t5/README.md b/t5x-main/t5x/examples/scalable_t5/README.md deleted file mode 100644 index de0d1903ea62af12fde9785caf53939a7518d86f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# Scalable T5 - -This directory is very similar to the vanilla T5X "T5" example, but demonstrates -a host of techniques needed to scale model training to giant models run on -large TPU or GPU cluster environments using XLA's SPMD capabilities. See the -notes for the main "t5" example for general details on setup and execution. - -## Intermediate variable annotations - -In larger models, with multi-axis model parallelism, it is typically necessary -to provide additional constraint annotations beyond those for the input and -output parameters for a function. We do this using a special version of the -`pjit` annotation function `with_sharding_constraint` that uses _logical_ axis -names instead of raw mesh axes. This allows us to avoid tightly coupling a -specific partitioning plan to the model code itself. Instead, we merely need -to annotate the axis names used in the model in a coherent scheme, and later -map these logical axes to the physical mesh axes using a small set of rules. -Example usage can be seen in `network.py`. - -## Scan over layers - -One challenge with giant models is the increasing amount of compilation time -required to handle extremely large layer stacks in XLA. At the size of a full -TPU pod this compile time cost can become quite extreme. To remedy this, -instead of handing the compiler a huge stack of unrolled layers, we can use -native XLA control flow constructs to simplify the computational graph given -from JAX. For giant models this can drop the compile time from hour(s) to -minutes, and even at base-scale can be roughly 5x faster. - -In this case, we want to use the [XLA While Op](xla-while) via JAX's -[scan](jax-scan) control flow construct to express the idea that we're looping -over identically-defined layers when using a deep transformer network. We do -this via a custom Flax version of scan called `scan_with_axes` that also handles -the parameter logical axis name metadata needed for partitioning. - -## Rematerialization / Checkpointing - -"Rematerialization" or "checkpointing" is a technique for trading off compute -time for lower peak memory utilization when performing reverse-mode automatic -differentiation. JAX offers several different default rematerialization -"policies" that dictate which kinds of intermediate values are preserved from -the forward-pass to the backwards-pass calculation, and which are discarded to -be recomputed anew in the backwards-pass. - - -[xla-while]: https://www.tensorflow.org/xla/operation_semantics#while -[jax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html diff --git a/t5x-main/t5x/examples/scalable_t5/__init__.py b/t5x-main/t5x/examples/scalable_t5/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/layers.py b/t5x-main/t5x/examples/scalable_t5/layers.py deleted file mode 100644 index 91d9c98b49869e0e24be234d19f89bdf6e43d01c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/layers.py +++ /dev/null @@ -1,936 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dense attention classes and mask/weighting functions.""" - -# pylint: disable=attribute-defined-outside-init,g-bare-generic - -import dataclasses -import functools -import operator -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union - -from flax import linen as nn -from flax.linen import partitioning as nn_partitioning -import jax -from jax import lax -from jax import random -import jax.numpy as jnp -import numpy as np - - -# from flax.linen.partitioning import param_with_axes, with_sharding_constraint -param_with_axes = nn_partitioning.param_with_axes -with_sharding_constraint = nn_partitioning.with_sharding_constraint - - -# Type annotations -Array = jnp.ndarray -DType = jnp.dtype -PRNGKey = jnp.ndarray -Shape = Sequence[int] -Activation = Callable[..., Array] -# Parameter initializers. -Initializer = Callable[[PRNGKey, Shape, DType], Array] -InitializerAxis = Union[int, Tuple[int, ...]] -NdInitializer = Callable[ - [PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array -] - -default_embed_init = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 -) - -variance_scaling = nn.initializers.variance_scaling - - -# ------------------------------------------------------------------------------ - - -def nd_dense_init(scale, mode, distribution): - """Initializer with in_axis, out_axis set at call time.""" - - def init_fn(key, shape, dtype, in_axis, out_axis): - fn = variance_scaling(scale, mode, distribution, in_axis, out_axis) - return fn(key, shape, dtype) - - return init_fn - - -def dot_product_attention( - query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0.0, - deterministic: bool = False, - dtype: DType = jnp.float32, - float32_logits: bool = False, -): - """Computes dot-product attention given query, key, and value. - - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It calculates the attention weights given - query and key and combines the values using the attention weights. - - Args: - query: queries for calculating attention with shape of `[batch, q_length, - num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of `[batch, kv_length, - num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of `[batch, kv_length, - num_heads, v_depth_per_head]`. - bias: bias for the attention weights. This should be broadcastable to the - shape `[batch, num_heads, q_length, kv_length]` This can be used for - incorporating causal masks, padding masks, proximity bias, etc. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - dtype: the dtype of the computation (default: float32) - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - - Returns: - Output of shape `[batch, length, num_heads, v_depth_per_head]`. - """ - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert ( - query.shape[:-3] == key.shape[:-3] == value.shape[:-3] - ), 'q, k, v batch dims must match.' - assert ( - query.shape[-2] == key.shape[-2] == value.shape[-2] - ), 'q, k, v num_heads must match.' - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' - - # Casting logits and softmax computation for float32 for model stability. - if float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) - - # `attn_weights`: [batch, num_heads, q_length, kv_length] - attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) - - # Apply attention bias: masking, dropout, proximity bias, etc. - if bias is not None: - attn_weights = attn_weights + bias.astype(attn_weights.dtype) - - # Normalize the attention weights across `kv_length` dimension. - attn_weights = jax.nn.softmax(attn_weights).astype(dtype) - - # Apply attention dropout. - if not deterministic and dropout_rate > 0.0: - keep_prob = 1.0 - dropout_rate - # T5 broadcasts along the "length" dim, but unclear which one that - # corresponds to in positional dimensions here, assuming query dim. - dropout_shape = list(attn_weights.shape) - dropout_shape[-2] = 1 - keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) - keep = jnp.broadcast_to(keep, attn_weights.shape) - multiplier = keep.astype(attn_weights.dtype) / jnp.asarray( - keep_prob, dtype=dtype - ) - attn_weights = attn_weights * multiplier - - # Take the linear combination of `value`. - return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) - - -dynamic_vector_slice_in_dim = jax.vmap( - lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None) -) - - -class MultiHeadDotProductAttention(nn.Module): - """Multi-head dot-product attention. - - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - head_dim: dimension of each head. - dtype: the dtype of the computation. - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - """ - - num_heads: int - head_dim: int - dtype: DType = jnp.float32 - dropout_rate: float = 0.0 - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') - float32_logits: bool = False # computes logits in float32 for stability. - - @nn.compact - def __call__( - self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - decode: bool = False, - deterministic: bool = False, - ) -> Array: - """Applies multi-head dot product attention on the input data. - - Projects the inputs into multi-headed query, key, and value vectors, - applies dot-product attention and project the results to an output vector. - - There are two modes: decoding and non-decoding (e.g., training). The mode is - determined by `decode` argument. For decoding, this method is called twice, - first to initialize the cache and then for an actual decoding process. The - two calls are differentiated by the presence of 'cached_key' in the variable - dict. In the cache initialization stage, the cache variables are initialized - as zeros and will be filled in the subsequent decoding process. - - In the cache initialization call, `inputs_q` has a shape [batch, length, - q_features] and `inputs_kv`: [batch, length, kv_features]. During the - incremental decoding stage, query, key and value all have the shape [batch, - 1, qkv_features] corresponding to a single step. - - Args: - inputs_q: input queries of shape `[batch, q_length, q_features]`. - inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. - mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. - bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. - decode: Whether to prepare and use an autoregressive cache. - deterministic: Disables dropout if set to True. - - Returns: - output of shape `[batch, length, q_features]`. - """ - projection = functools.partial( - DenseGeneral, - axis=-1, - features=(self.num_heads, self.head_dim), - kernel_axes=('embed', 'heads', 'kv'), - dtype=self.dtype, - ) - - # NOTE: T5 does not explicitly rescale the attention logits by - # 1/sqrt(depth_kq)! This is folded into the initializers of the - # linear transformations, which is equivalent under Adafactor. - depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - query_init = lambda *args: self.kernel_init(*args) / depth_scaling - - # Project inputs_q to multi-headed q/k/v - # dimensions are then [batch, length, num_heads, head_dim] - query = projection(kernel_init=query_init, name='query')(inputs_q) - key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) - value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) - - query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) - key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) - value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) - - if decode: - # Detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable('cache', 'cached_key') - # The key and value have dimension [batch, length, num_heads, head_dim], - # but we cache them as [batch, num_heads, head_dim, length] as a TPU - # fusion optimization. This also enables the "scatter via one-hot - # broadcast" trick, which means we do a one-hot broadcast instead of a - # scatter/gather operations, resulting in a 3-4x speedup in practice. - swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) - cached_key = self.variable( - 'cache', 'cached_key', jnp.zeros, swap_dims(key.shape), key.dtype - ) - cached_value = self.variable( - 'cache', - 'cached_value', - jnp.zeros, - swap_dims(value.shape), - value.dtype, - ) - cache_index = self.variable( - 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) - ) - if is_initialized: - batch, num_heads, head_dim, length = cached_key.value.shape - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - # Sanity shape check of cached key against input query. - expected_shape = (batch, 1, num_heads, head_dim) - if expected_shape != query.shape: - raise ValueError( - 'Autoregressive cache shape error, ' - 'expected query shape %s instead got %s.' - % (expected_shape, query.shape) - ) - - # Create a OHE of the current index. NOTE: the index is increased below. - cur_index = cache_index.value - one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) - # In order to update the key, value caches with the current key and - # value, we move the length axis to the back, similar to what we did for - # the cached ones above. - # Note these are currently the key and value of a single position, since - # we feed one position at a time. - one_token_key = jnp.moveaxis(key, -3, -1) - one_token_value = jnp.moveaxis(value, -3, -1) - # Update key, value caches with our new 1d spatial slices. - # We implement an efficient scatter into the cache via one-hot - # broadcast and addition. - key = cached_key.value + one_token_key * one_hot_indices - value = cached_value.value + one_token_value * one_hot_indices - cached_key.value = key - cached_value.value = value - cache_index.value = cache_index.value + 1 - # Move the keys and values back to their original shapes. - key = jnp.moveaxis(key, -1, -3) - value = jnp.moveaxis(value, -1, -3) - - # Causal mask for cached decoder self-attention: our single query - # position should only attend to those key positions that have already - # been generated and cached, not the remaining zero elements. - mask = combine_masks( - mask, - jnp.broadcast_to( - jnp.arange(length) <= cur_index, - # (1, 1, length) represent (head dim, query length, key length) - # query length is 1 because during decoding we deal with one - # index. - # The same mask is applied to all batch elements and heads. - (batch, 1, 1, length), - ), - ) - - # Grab the correct relative attention bias during decoding. This is - # only required during single step decoding. - if bias is not None: - # The bias is a full attention matrix, but during decoding we only - # have to take a slice of it. - # This is equivalent to bias[..., cur_index:cur_index+1, :]. - bias = dynamic_vector_slice_in_dim( - jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2 - ) - - # Convert the boolean attention mask to an attention bias. - if mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - mask > 0, - jnp.full(mask.shape, 0.0).astype(self.dtype), - jnp.full(mask.shape, -1e10).astype(self.dtype), - ) - else: - attention_bias = None - - # Add provided bias term (e.g. relative position embedding). - if bias is not None: - attention_bias = combine_biases(attention_bias, bias) - - dropout_rng = None - if not deterministic and self.dropout_rate > 0.0: - dropout_rng = self.make_rng('dropout') - - # Apply attention. - x = dot_product_attention( - query, - key, - value, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - dtype=self.dtype, - float32_logits=self.float32_logits, - ) - - # Back to the original inputs dimensions. - out = DenseGeneral( - features=inputs_q.shape[-1], # output dim is set to the input dim. - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('heads', 'kv', 'embed'), - dtype=self.dtype, - name='out', - )(x) - return out - - -def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]: - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) - - -def _canonicalize_tuple(x): - if isinstance(x, Iterable): - return tuple(x) - else: - return (x,) - - -# ------------------------------------------------------------------------------ -# DenseGeneral for attention layers. -# ------------------------------------------------------------------------------ -class DenseGeneral(nn.Module): - """A linear transformation (without bias) with flexible axes. - - Attributes: - features: tuple with numbers of output features. - axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - """ - - features: Union[Iterable[int], int] - axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') - kernel_axes: Tuple[str, ...] = () - - @nn.compact - def __call__(self, inputs: Array) -> Array: - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - - Returns: - The transformed input. - """ - features = _canonicalize_tuple(self.features) - axis = _canonicalize_tuple(self.axis) - - inputs = jnp.asarray(inputs, self.dtype) - axis = _normalize_axes(axis, inputs.ndim) - - kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features - kernel_in_axis = np.arange(len(axis)) - kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) - kernel = param_with_axes( - 'kernel', - self.kernel_init, - kernel_shape, - jnp.float32, - kernel_in_axis, - kernel_out_axis, - axes=self.kernel_axes, - ) - kernel = jnp.asarray(kernel, self.dtype) - - contract_ind = tuple(range(0, len(axis))) - return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) - - -def _convert_to_activation_function( - fn_or_string: Union[str, Callable] -) -> Callable: - """Convert a string to an activation function.""" - if fn_or_string == 'linear': - return lambda x: x - elif isinstance(fn_or_string, str): - return getattr(nn, fn_or_string) - elif callable(fn_or_string): - return fn_or_string - else: - raise ValueError( - "don't know how to convert %s to an activation function" - % (fn_or_string,) - ) - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. - - Attributes: - intermediate_dim: Shared dimension of hidden layers. - activations: Type of activations for each layer. Each element is either - 'linear', a string function name in flax.linen, or a function. - kernel_init: Kernel function, passed to the dense layers. - deterministic: Whether the dropout layers should be deterministic. - intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. - """ - - intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable]] = ('relu',) - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') - intermediate_dropout_rate: float = 0.1 - dtype: Any = jnp.float32 - - @nn.compact - def __call__(self, inputs, decode: bool = False, deterministic: bool = False): - """Applies Transformer MlpBlock module.""" - # Iterate over specified MLP input activation functions. - # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. - activations = [] - for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' - x = DenseGeneral( - self.intermediate_dim, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), - name=dense_name, - )(inputs) - x = _convert_to_activation_function(act_fn)(x) - activations.append(x) - - # Take elementwise product of above intermediate activations. - x = functools.reduce(operator.mul, activations) - # Apply dropout and final dense output projection. - x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) # Broadcast along length. - x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) - output = DenseGeneral( - inputs.shape[-1], - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - name='wo', - )(x) - return output - - -class Embed(nn.Module): - """A parameterized function from integers [0, n) to d-dimensional vectors. - - Attributes: - num_embeddings: number of embeddings. - features: number of feature dimensions for each embedding. - dtype: the dtype of the embedding vectors (default: float32). - embedding_init: embedding initializer. - one_hot: performs the gather with a one-hot contraction rather than a true - gather. This is currently needed for SPMD partitioning. - """ - - num_embeddings: int - features: int - cast_input_dtype: Optional[DType] = None - dtype: DType = jnp.float32 - attend_dtype: Optional[DType] = None - embedding_init: Initializer = default_embed_init - one_hot: bool = False - embedding: Array = dataclasses.field(init=False) - - def setup(self): - self.embedding = param_with_axes( - 'embedding', - self.embedding_init, - (self.num_embeddings, self.features), - jnp.float32, - axes=('vocab', 'embed'), - ) - - def __call__(self, inputs: Array) -> Array: - """Embeds the inputs along the last dimension. - - Args: - inputs: input data, all dimensions are considered batch dimensions. - - Returns: - Output which is embedded input data. The output shape follows the input, - with an additional `features` dimension appended. - """ - if self.cast_input_dtype: - inputs = inputs.astype(self.cast_input_dtype) - if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') - if self.one_hot: - iota = lax.iota(jnp.int32, self.num_embeddings) - one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) - output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) - else: - output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = with_sharding_constraint(output, ('batch', 'length', 'embed')) - return output - - def attend(self, query: Array) -> Array: - """Attend over the embedding using a query array. - - Args: - query: array with last dimension equal the feature depth `features` of the - embedding. - - Returns: - An array with final dim `num_embeddings` corresponding to the batched - inner-product of the array of query vectors against each embedding. - Commonly used for weight-sharing between embeddings and logit transform - in NLP models. - """ - dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype - return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) - - -class RelativePositionBiases(nn.Module): - """Adds T5-style relative positional embeddings to the attention logits. - - Attributes: - num_buckets: Number of buckets to bucket distances between key and query - positions into. - max_distance: Maximum distance before everything is lumped into the last - distance bucket. - num_heads: Number of heads in the attention layer. Each head will get a - different relative position weighting. - dtype: Type of arrays through this module. - embedding_init: initializer for relative embedding table. - """ - - num_buckets: int - max_distance: int - num_heads: int - dtype: Any - embedding_init: Callable[..., Array] = nn.linear.default_embed_init - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """Translate relative position to a bucket number for relative attention. - - The relative position is defined as memory_position - query_position, i.e. - the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are - invalid. - We use smaller buckets for small absolute relative_position and larger - buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative - positions <=-max_distance map to the same bucket. This should allow for - more graceful generalization to longer sequences than the model has been - trained on. - - Args: - relative_position: an int32 array - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 - values in the range [0, num_buckets) - """ - ret = 0 - n = -relative_position - if bidirectional: - num_buckets //= 2 - ret += (n < 0).astype(np.int32) * num_buckets - n = np.abs(n) - else: - n = np.maximum(n, 0) - # now n is in the range [0, inf) - max_exact = num_buckets // 2 - is_small = n < max_exact - val_if_large = max_exact + ( - np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) - / np.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).astype(np.int32) - val_if_large = np.minimum(val_if_large, num_buckets - 1) - ret += np.where(is_small, n, val_if_large) - return ret - - @nn.compact - def __call__(self, qlen, klen, bidirectional=True): - """Produce relative position embedding attention biases. - - Args: - qlen: attention query length. - klen: attention key length. - bidirectional: whether to allow positive memory-query relative position - embeddings. - - Returns: - output: `(1, len, q_len, k_len)` attention bias - """ - # TODO(levskaya): should we be computing this w. numpy as a program - # constant? - context_position = np.arange(qlen, dtype=jnp.int32)[:, None] - memory_position = np.arange(klen, dtype=jnp.int32)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - rp_bucket = self._relative_position_bucket( - relative_position, - bidirectional=bidirectional, - num_buckets=self.num_buckets, - max_distance=self.max_distance, - ) - relative_attention_bias = param_with_axes( - 'rel_embedding', - self.embedding_init, - (self.num_heads, self.num_buckets), - jnp.float32, - axes=('heads', 'relpos_buckets'), - ) - - relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) - # Instead of using a slow gather, we create a leading-dimension one-hot - # array from rp_bucket and use it to perform the gather-equivalent via a - # contraction, i.e.: - # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). - # This is equivalent to relative_attention_bias[:, rp_bucket] - bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) - rp_bucket_one_hot = jnp.array( - rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype - ) - # --> shape (qlen, klen, num_heads) - values = lax.dot_general( - relative_attention_bias, - rp_bucket_one_hot, - (((1,), (0,)), ((), ())), # rhs, lhs contracting dims - ) # no batched dims - # Add a singleton batch dimension. - # --> shape (1, num_heads, qlen, klen) - return values[jnp.newaxis, ...] - - -# ------------------------------------------------------------------------------ -# T5 Layernorm - no subtraction of mean or bias. -# ------------------------------------------------------------------------------ -class LayerNorm(nn.Module): - """T5 Layer normalization operating on the last axis of the input data.""" - - epsilon: float = 1e-6 - dtype: Any = jnp.float32 - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - """Applies layer normalization on the input.""" - x = jnp.asarray(x, jnp.float32) - features = x.shape[-1] - mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) - scale = param_with_axes( - 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',) - ) - - scale = jnp.asarray(scale, self.dtype) - return y * scale - - -# ------------------------------------------------------------------------------ -# Mask-making utility functions. -# ------------------------------------------------------------------------------ -def make_attention_mask( - query_input: Array, - key_input: Array, - pairwise_fn: Callable = jnp.multiply, - extra_batch_dims: int = 0, - dtype: DType = jnp.float32, -) -> Array: - """Mask-making helper for attention weights. - - In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the - attention weights will be `[batch, heads, len_q, len_kv]` and this - function will produce `[batch, 1, len_q, len_kv]`. - - Args: - query_input: a batched, flat input of query_length size - key_input: a batched, flat input of key_length size - pairwise_fn: broadcasting elementwise comparison function - extra_batch_dims: number of extra batch dims to add singleton axes for, none - by default - dtype: mask return dtype - - Returns: - A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. - """ - # [batch, len_q, len_kv] - mask = pairwise_fn( - # [batch, len_q] -> [batch, len_q, 1] - jnp.expand_dims(query_input, axis=-1), - # [batch, len_q] -> [batch, 1, len_kv] - jnp.expand_dims(key_input, axis=-2), - ) - - # [batch, 1, len_q, len_kv]. This creates the head dim. - mask = jnp.expand_dims(mask, axis=-3) - mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) - return mask.astype(dtype) - - -def make_causal_mask( - x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32 -) -> Array: - """Make a causal mask for self-attention. - - In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights - will be `[batch, heads, len, len]` and this function will produce a - causal mask of shape `[batch, 1, len, len]`. - - Note that a causal mask does not depend on the values of x; it only depends on - the shape. If x has padding elements, they will not be treated in a special - manner. - - Args: - x: input array of shape `[batch, len]` - extra_batch_dims: number of batch dims to add singleton axes for, none by - default - dtype: mask return dtype - - Returns: - A `[batch, 1, len, len]` shaped causal mask for 1d attention. - """ - idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) - return make_attention_mask( - idxs, - idxs, - jnp.greater_equal, - extra_batch_dims=extra_batch_dims, - dtype=dtype, - ) - - -def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): - """Combine attention masks. - - Args: - *masks: set of attention mask arguments to combine, some can be None. - dtype: final mask dtype - - Returns: - Combined mask, reduced by logical and, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all( - map(lambda x: x.ndim == masks[0].ndim, masks) - ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}' - mask, *other_masks = masks - for other_mask in other_masks: - mask = jnp.logical_and(mask, other_mask) - return mask.astype(dtype) - - -def combine_biases(*masks: Optional[Array]): - """Combine attention biases. - - Args: - *masks: set of attention bias arguments to combine, some can be None. - - Returns: - Combined mask, reduced by summation, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all( - map(lambda x: x.ndim == masks[0].ndim, masks) - ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}' - mask, *other_masks = masks - for other_mask in other_masks: - mask = mask + other_mask - return mask - - -def make_decoder_mask( - decoder_target_tokens: Array, - dtype: DType, - decoder_causal_attention: Optional[Array] = None, - decoder_segment_ids: Optional[Array] = None, -) -> Array: - """Compute the self-attention mask for a decoder. - - Decoder mask is formed by combining a causal mask, a padding mask and an - optional packing mask. If decoder_causal_attention is passed, it makes the - masking non-causal for positions that have value of 1. - - A prefix LM is applied to a dataset which has a notion of "inputs" and - "targets", e.g., a machine translation task. The inputs and targets are - concatenated to form a new target. `decoder_target_tokens` is the concatenated - decoder output tokens. - - The "inputs" portion of the concatenated sequence can attend to other "inputs" - tokens even for those at a later time steps. In order to control this - behavior, `decoder_causal_attention` is necessary. This is a binary mask with - a value of 1 indicating that the position belonged to "inputs" portion of the - original dataset. - - Example: - - Suppose we have a dataset with two examples. - - ds = [{"inputs": [6, 7], "targets": [8]}, - {"inputs": [3, 4], "targets": [5]}] - - After the data preprocessing with packing, the two examples are packed into - one example with the following three fields (some fields are skipped for - simplicity). - - decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] - decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] - decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] - - where each array has [batch, length] shape with batch size being 1. Then, - this function computes the following mask. - - mask = [[[[1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0]]]] - - mask[b, 1, :, :] represents the mask for the example `b` in the batch. - Because mask is for a self-attention layer, the mask's shape is a square of - shape [query length, key length]. - - mask[b, 1, i, j] = 1 means that the query token at position i can attend to - the key token at position j. - - Args: - decoder_target_tokens: decoder output tokens. [batch, length] - dtype: dtype of the output mask. - decoder_causal_attention: a binary mask indicating which position should - only attend to earlier positions in the sequence. Others will attend - bidirectionally. [batch, length] - decoder_segment_ids: decoder segmentation info for packed examples. [batch, - length] - - Returns: - the combined decoder mask. - """ - masks = [] - # The same mask is applied to all attention heads. So the head dimension is 1, - # i.e., the mask will be broadcast along the heads dim. - # [batch, 1, length, length] - causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) - - # Positions with value 1 in `decoder_causal_attneition` can attend - # bidirectionally. - if decoder_causal_attention is not None: - # [batch, 1, length, length] - inputs_mask = make_attention_mask( - decoder_causal_attention, - decoder_causal_attention, - jnp.logical_and, - dtype=dtype, - ) - masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) - else: - masks.append(causal_mask) - - # Padding mask. - masks.append( - make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype - ) - ) - - # Packing mask - if decoder_segment_ids is not None: - masks.append( - make_attention_mask( - decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype - ) - ) - - return combine_masks(*masks, dtype=dtype) # pytype: disable=bad-return-type # jax-ndarray diff --git a/t5x-main/t5x/examples/scalable_t5/layers_test.py b/t5x-main/t5x/examples/scalable_t5/layers_test.py deleted file mode 100644 index bdfbd24d3f360f4c575d2da613b86eaac8cbdb6d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/layers_test.py +++ /dev/null @@ -1,701 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for attention classes.""" - -import dataclasses -from typing import Optional -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from flax import linen as nn -from flax.core import freeze -from flax.linen import partitioning as nn_partitioning -import jax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp -import numpy as np -from t5x.examples.scalable_t5 import layers - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -Array = jnp.ndarray -AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name - - -class SelfAttention(layers.MultiHeadDotProductAttention): - """Self-attention special case of multi-head dot-product attention.""" - - @nn.compact - def __call__( - self, - inputs_q: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - deterministic: bool = False, - ): - return super().__call__( - inputs_q, inputs_q, mask, bias, deterministic=deterministic - ) - - -@dataclasses.dataclass(frozen=True) -class SelfAttentionArgs: - num_heads: int = 1 - batch_size: int = 2 - # qkv_features: int = 3 - head_dim: int = 3 - # out_features: int = 4 - q_len: int = 5 - features: int = 6 - dropout_rate: float = 0.1 - deterministic: bool = False - decode: bool = False - float32_logits: bool = False - - def __post_init__(self): - # If we are doing decoding, the query length should be 1, because are doing - # autoregressive decoding where we feed one position at a time. - assert not self.decode or self.q_len == 1 - - def init_args(self): - return dict( - num_heads=self.num_heads, - head_dim=self.head_dim, - dropout_rate=self.dropout_rate, - float32_logits=self.float32_logits, - ) - - def apply_args(self): - inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) - mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - return { - 'inputs_q': inputs_q, - 'mask': mask, - 'bias': bias, - 'deterministic': self.deterministic, - } - - -class AttentionTest(parameterized.TestCase): - - def test_dot_product_attention_shape(self): - # This test only checks for shape but tries to make sure all code paths are - # reached. - dropout_rng = random.PRNGKey(0) - batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 - - query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) - key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) - value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) - bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) - - args = dict( - query=query, - key=key, - value=value, - bias=bias, - dropout_rng=dropout_rng, - dropout_rate=0.5, - deterministic=False, - ) - - output = layers.dot_product_attention(**args) - self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) - - def test_make_attention_mask_multiply_pairwise_fn(self): - decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) - attention_mask = layers.make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32 - ) - expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) - expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) - self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_attention_mask_equal_pairwise_fn(self): - segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) - attention_mask = layers.make_attention_mask( - segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32 - ) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected0 = jnp.array([ - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1], - ]) - expected1 = jnp.array([ - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1], - ]) - self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_causal_mask_with_padding(self): - x = jnp.array([[7, 0, 0], [8, 5, 0]]) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (2, 1, 3, 3)) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected_y = jnp.array( - [[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]], jnp.float32 - ) - np.testing.assert_allclose(y[0], expected_y) - np.testing.assert_allclose(y[1], expected_y) - - def test_make_causal_mask_extra_batch_dims(self): - x = jnp.ones((3, 3, 5)) - y = layers.make_causal_mask(x, extra_batch_dims=2) - self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) - - def test_make_causal_mask(self): - x = jnp.ones((1, 3)) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (1, 1, 3, 3)) - expected_y = jnp.array( - [[[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]]], jnp.float32 - ) - np.testing.assert_allclose(y, expected_y) - - def test_combine_masks(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), - None, - jnp.array([1, 1, 1, 1], jnp.float32), - jnp.array([1, 1, 1, 0], jnp.float32), - ] - y = layers.combine_masks(*masks) - np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) - - def test_combine_biases(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), - None, - jnp.array([0, 1, 1, 1], jnp.float32), - jnp.array([0, 1, 1, 0], jnp.float32), - ] - y = layers.combine_biases(*masks) - np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) - - def test_make_decoder_mask_lm_unpacked(self): - decoder_target_tokens = jnp.array([6, 7, 3, 0]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32 - ) - expected_mask = jnp.array( - [[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]]] - ) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_lm_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask = jnp.array([[[ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask = jnp.array( - [[[ - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ]]], - dtype=jnp.float32, - ) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_packed(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask = jnp.array([[[ - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask0 = jnp.array( - [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]] - ) - expected_mask1 = jnp.array( - [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] - ) - self.assertEqual(mask.shape, (2, 1, 4, 4)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - np.testing.assert_array_equal(mask[1, 0], expected_mask1) - - def test_make_decoder_mask_composite_causal_attention(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask0 = jnp.array([ - [1, 1, 0, 0, 1, 1, 0], - [1, 1, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ]) - - self.assertEqual(mask.shape, (1, 1, 7, 7)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - def test_make_decoder_mask_composite_causal_attention_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask0 = jnp.array([ - [1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1], - ]) - - self.assertEqual(mask.shape, (1, 1, 9, 9)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - @parameterized.parameters({'f': 20}, {'f': 22}) - def test_multihead_dot_product_attention(self, f): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - np.random.seed(0) - inputs_q = np.random.randn(b, q, f) - inputs_kv = np.random.randn(b, k, f) - - # Projection: [b, q, f] -> [b, q, h, d] - # So the kernels have to be [f, h, d] - query_kernel = np.random.randn(f, h, d) - key_kernel = np.random.randn(f, h, d) - value_kernel = np.random.randn(f, h, d) - # `out` calculation: [b, q, h, d] -> [b, q, f] - # So kernel has to be [h, d, f] - out_kernel = np.random.randn(h, d, f) - - params = { - 'query': {'kernel': query_kernel}, - 'key': {'kernel': key_kernel}, - 'value': {'kernel': value_kernel}, - 'out': {'kernel': out_kernel}, - } - y = layers.MultiHeadDotProductAttention(**args).apply( - {'params': freeze(params)}, inputs_q, inputs_kv - ) - - query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) - key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) - value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = nn.softmax(logits, axis=-1) - combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) - y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) - np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) - - def test_multihead_dot_product_attention_caching(self): - # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim - b, h, d, k = 2, 3, 4, 5 - f = h * d - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - cache = { - 'cached_key': np.zeros((b, h, d, k)), - 'cached_value': np.zeros((b, h, d, k)), - 'cache_index': np.array(0), - } - inputs_q = np.random.randn(b, 1, f) - inputs_kv = np.random.randn(b, 1, f) - - # Mock dense general such that q, k, v projections are replaced by simple - # reshaping. - def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument - return x.reshape(b, -1, h, d) - - with mock.patch.object( - layers.DenseGeneral, '__call__', new=mock_dense_general - ): - _, mutated = layers.MultiHeadDotProductAttention(**args).apply( - {'cache': freeze(cache)}, - inputs_q, - inputs_kv, - decode=True, - mutable=['cache'], - ) - updated_cache = mutated['cache'] - - # Perform the same mocked projection to generate the expected cache. - # (key|value): [b, 1, h, d] - key = mock_dense_general(None, inputs_kv) - value = mock_dense_general(None, inputs_kv) - - # cached_(key|value): [b, h, d, k] - cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] - cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] - cache['cache_index'] = np.array(1) - for name, array in cache.items(): - np.testing.assert_allclose(array, updated_cache[name]) - - def test_dot_product_attention(self): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - np.random.seed(0) - query = np.random.randn(b, q, h, d) - key = np.random.randn(b, k, h, d) - value = np.random.randn(b, k, h, d) - bias = np.random.randn(b, h, q, k) - attn_out = layers.dot_product_attention(query, key, value, bias=bias) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = jax.nn.softmax(logits + bias, axis=-1) - expected = np.einsum('bhqk,bkhd->bqhd', weights, value) - np.testing.assert_allclose(attn_out, expected, atol=1e-6) - - -class EmbeddingTest(parameterized.TestCase): - - def test_embedder_raises_exception_for_incorrect_input_type(self): - """Tests that inputs are integers and that an exception is raised if not.""" - embed = layers.Embed(num_embeddings=10, features=5) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - bad_inputs = inputs.astype(np.float32) - with self.assertRaisesRegex( - ValueError, 'Input type must be an integer or unsigned integer.' - ): - _ = embed.apply(variables, bad_inputs) - - @parameterized.named_parameters( - { - 'testcase_name': 'with_ones', - 'init_fn': jax.nn.initializers.ones, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 5 * 10, - }, - { - 'testcase_name': 'with_zeros', - 'init_fn': jax.nn.initializers.zeros, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 0, - }, - ) - def test_embedding_initializes_correctly( - self, init_fn, num_embeddings, features, matrix_sum - ): - """Tests if the Embed class initializes with the requested initializer.""" - embed = layers.Embed( - num_embeddings=num_embeddings, features=features, embedding_init=init_fn - ) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) - - def test_embedding_matrix_shape(self): - """Tests that the embedding matrix has the right shape.""" - num_embeddings = 10 - features = 5 - embed = layers.Embed(num_embeddings=num_embeddings, features=features) - inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual((num_embeddings, features), embedding_matrix.shape) - - def test_embedding_attend(self): - """Tests that attending with ones returns sum of embedding vectors.""" - features = 5 - embed = layers.Embed(num_embeddings=10, features=features) - inputs = np.array([[1]], dtype=np.int64) - variables = embed.init(jax.random.PRNGKey(0), inputs) - query = np.ones(features, dtype=np.float32) - result = embed.apply(variables, query, method=embed.attend) - expected = np.sum(variables['params']['embedding'], -1) - np.testing.assert_array_almost_equal(result, expected) - - -class DenseTest(parameterized.TestCase): - - def test_dense_general_no_bias(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=4, - kernel_init=lambda k, s, d, ai, ao: initializers.ones(k, s, d), - ) - y, _ = model.init_with_output(rng, x) - self.assertEqual(y.shape, (1, 4)) - np.testing.assert_allclose(y, np.full((1, 4), 3.0)) - - def test_dense_general_two_features(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=(2, 2), - kernel_init=lambda k, s, d, ai, ao: initializers.ones(k, s, d), - ) - y, _ = model.init_with_output(rng, x) - # We transform the last input dimension to two output dimensions (2, 2). - np.testing.assert_allclose(y, np.full((1, 2, 2), 3.0)) - - def test_dense_general_two_axes(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 2, 2)) - model = layers.DenseGeneral( - features=3, - axis=(-2, 2), # Note: this is the same as (1, 2). - kernel_init=lambda k, s, d, ai, ao: initializers.ones(k, s, d), - ) - y, _ = model.init_with_output(rng, x) - # We transform the last two input dimensions (2, 2) to one output dimension. - np.testing.assert_allclose(y, np.full((1, 3), 4.0)) - - def test_mlp_same_out_dim(self): - module = layers.MlpBlock( - intermediate_dim=4, - activations=('relu',), - kernel_init=layers.nd_dense_init(1.0, 'fan_avg', 'uniform'), - dtype=jnp.float32, - ) - inputs = np.array( - [ - # Batch 1. - [[1, 1], [1, 1], [1, 2]], - # Batch 2. - [[2, 2], [3, 1], [2, 2]], - ], - dtype=np.float32, - ) - params = module.init(random.PRNGKey(0), inputs, deterministic=True) - # self.assertEqual( - # jax.tree.map(lambda a: a.tolist(), params), - # { - # 'params': { - # 'wi': { - # 'kernel': [ - # [ - # -0.8675811290740967, - # 0.08417510986328125, - # 0.022586345672607422, - # -0.9124102592468262, - # ], - # [ - # -0.19464373588562012, - # 0.49809837341308594, - # 0.7808468341827393, - # 0.9267289638519287, - # ], - # ], - # }, - # 'wo': { - # 'kernel': [ - # [0.01154780387878418, 0.1397249698638916], - # [0.974980354309082, 0.5903260707855225], - # [-0.05997943878173828, 0.616570234298706], - # [0.2934272289276123, 0.8181164264678955], - # ], - # }, - # }, - # 'params_axes': { - # 'wi': { - # 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), - # }, - # 'wo': { - # 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), - # }, - # }, - # }, - # ) - result = module.apply(params, inputs, deterministic=True) # pylint: disable=unused-variable - # np.testing.assert_allclose( - # result.tolist(), - # [ - # [ - # [0.5237172245979309, 0.8508185744285583], - # [0.5237172245979309, 0.8508185744285583], - # [1.2344461679458618, 2.3844780921936035], - # ], - # [ - # [1.0474344491958618, 1.7016371488571167], - # [0.6809444427490234, 0.9663378596305847], - # [1.0474344491958618, 1.7016371488571167], - # ], - # ], - # rtol=1e-6, - # ) - - -class RelativePositionBiasesTest(absltest.TestCase): - - def setUp(self): - self.num_heads = 3 - self.query_len = 5 - self.key_len = 7 - self.relative_attention = layers.RelativePositionBiases( - num_buckets=12, - max_distance=10, - num_heads=3, - dtype=jnp.float32, - ) - super(RelativePositionBiasesTest, self).setUp() - - def test_relative_attention_bidirectional_params(self): - """Tests that bidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True - ) - param_shapes = jax.tree.map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, - { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': AxisMetadata( - names=('heads', 'relpos_buckets') - ), - }, - }, - ) - - def test_regression_relative_attention_bidirectional_values(self): - """Tests that bidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True - ) - self.assertEqual( - outputs.shape, (1, self.num_heads, self.query_len, self.key_len) - ) - # self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - # self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - # self.assertAlmostEqual(outputs[0, 1, 4, 6], 0.14510104, places=5) - # self.assertAlmostEqual(outputs[0, 2, 4, 6], -0.36783996, places=5) - - def test_relative_attention_unidirectional_params(self): - """Tests that unidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False - ) - param_shapes = jax.tree.map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, - { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': AxisMetadata( - names=('heads', 'relpos_buckets') - ), - }, - }, - ) - - def test_regression_relative_attention_unidirectional_values(self): - """Tests that unidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False - ) - self.assertEqual( - outputs.shape, (1, self.num_heads, self.query_len, self.key_len) - ) - # self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - # self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - # self.assertAlmostEqual(outputs[0, 1, 4, 6], -0.13101986, places=5) - # self.assertAlmostEqual(outputs[0, 2, 4, 6], 0.39296466, places=5) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/examples/scalable_t5/local_tiny.gin b/t5x-main/t5x/examples/scalable_t5/local_tiny.gin deleted file mode 100644 index 3d7b28429a920a5a595afdf897f9525a8f9e1487..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/local_tiny.gin +++ /dev/null @@ -1,70 +0,0 @@ -# A gin file to make the Transformer models tiny for faster local testing. -# -# When testing locally with CPU, there are a few things that we need. -# - tiny model size -# - small enough batch size -# - small sequence length -# - determinstic dataset pipeline -# -# This gin file adds such configs. To use this gin file, add it on top of the -# existing full-scale gin files. The ordering of the gin file matters. So this -# should be added after all the other files are added to override the same -# configurables. - -from __gin__ import dynamic_registration - -from t5x import partitioning -from t5x import trainer -from t5x import utils -from t5x.examples.t5 import network - -import __main__ as train_script - -train_script.train.random_seed = 42 # dropout seed -train/utils.DatasetConfig.seed = 42 # dataset seed - -TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 7} -LABEL_SMOOTHING = 0.0 - -# Network specification overrides -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 - mlp_activations = ('gelu', 'linear') - dropout_rate = 0.0 - logits_via_embedding = False - scan_layers = True - remat_policy = 'minimal' - -TRAIN_STEPS = 3 - -train/utils.DatasetConfig: - batch_size = 8 - shuffle = False - -train_eval/utils.DatasetConfig.batch_size = 8 - -train_script.train: - eval_period = 3 - eval_steps = 3 - -trainer.Trainer.num_microbatches = 0 -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - -utils.CheckpointConfig: - restore = None - -infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS - - -# DISABLE INFERENCE EVAL -# train_script.train.infer_eval_dataset_cfg = None diff --git a/t5x-main/t5x/examples/scalable_t5/mt5/__init__.py b/t5x-main/t5x/examples/scalable_t5/mt5/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/mt5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/mt5/base.gin b/t5x-main/t5x/examples/scalable_t5/mt5/base.gin deleted file mode 100644 index b113eb4d4ac3f16dde2ad5debda6ad106ea2769e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/mt5/base.gin +++ /dev/null @@ -1,58 +0,0 @@ -# MT5 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.examples.scalable_t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" - - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 250112 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 2048 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False - scan_layers = True - remat_policy = 'minimal' diff --git a/t5x-main/t5x/examples/scalable_t5/mt5/large.gin b/t5x-main/t5x/examples/scalable_t5/mt5/large.gin deleted file mode 100644 index ff5a70ca2034003443f509855ca7f0348f59d6fd..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/mt5/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Large model. - -include 't5x/examples/scalable_t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 2816 diff --git a/t5x-main/t5x/examples/scalable_t5/mt5/small.gin b/t5x-main/t5x/examples/scalable_t5/mt5/small.gin deleted file mode 100644 index aed9f8412f8c6d812a79dbfa686d1ee0a9c8442d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/mt5/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Small model. - -include 't5x/examples/scalable_t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 6 - num_encoder_layers = 8 - num_decoder_layers = 8 - head_dim = 64 - mlp_dim = 1024 diff --git a/t5x-main/t5x/examples/scalable_t5/mt5/xl.gin b/t5x-main/t5x/examples/scalable_t5/mt5/xl.gin deleted file mode 100644 index 37e6730921626ecda53788410178e94bc4930651..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/mt5/xl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 XL model. - -include 't5x/examples/scalable_t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2048 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 5120 diff --git a/t5x-main/t5x/examples/scalable_t5/mt5/xxl.gin b/t5x-main/t5x/examples/scalable_t5/mt5/xxl.gin deleted file mode 100644 index 135ee52c08e5ca42ff18232c7b1556e053cf279a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/mt5/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 XXL model. - -include 't5x/examples/scalable_t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4096 - num_heads = 64 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 10240 diff --git a/t5x-main/t5x/examples/scalable_t5/network.py b/t5x-main/t5x/examples/scalable_t5/network.py deleted file mode 100644 index b20a7c56d3800f64d4957127cf35959ea71a4cf9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/network.py +++ /dev/null @@ -1,559 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5.1.1 Transformer model.""" - -from typing import Any, Sequence - -from flax import linen as nn -from flax import struct -from flax.linen import partitioning as nn_partitioning -import jax -import jax.numpy as jnp -from t5x.examples.scalable_t5 import layers - -with_sharding_constraint = nn_partitioning.with_sharding_constraint -scan_with_axes = nn_partitioning.scan_with_axes -remat = nn_partitioning.remat -ScanIn = nn_partitioning.ScanIn - - -@struct.dataclass -class T5Config: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - - vocab_size: int - # Activation dtypes. - dtype: Any = jnp.float32 - emb_dim: int = 512 - num_heads: int = 8 - num_encoder_layers: int = 6 - num_decoder_layers: int = 6 - head_dim: int = 64 - mlp_dim: int = 2048 - # Activation functions are retrieved from Flax. - mlp_activations: Sequence[str] = ('relu',) - dropout_rate: float = 0.1 - # If `True`, the embedding weights are used in the decoder output layer. - logits_via_embedding: bool = False - # minimal, full, or none - remat_policy: str = 'none' - scan_layers: bool = True - param_scan_axis: int = 1 - - -class EncoderLayer(nn.Module): - """Transformer encoder layer.""" - - config: T5Config - - @nn.compact - def __call__(self, inputs, encoder_mask=None, deterministic=False): - cfg = self.config - - # Relative position embedding as attention biases. - encoder_bias = layers.RelativePositionBiases( - num_buckets=32, - max_distance=128, - num_heads=cfg.num_heads, - dtype=cfg.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform' - ), - name='relpos_bias', - )(inputs.shape[-2], inputs.shape[-2], True) - - # Attention block. - assert inputs.ndim == 3 - inputs = with_sharding_constraint(inputs, ('batch', 'length', 'embed')) - x = layers.LayerNorm(dtype=cfg.dtype, name='pre_attention_layer_norm')( - inputs - ) - x = with_sharding_constraint(x, ('batch', 'length', 'embed')) - # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - name='attention', - )(x, x, encoder_mask, encoder_bias, deterministic=deterministic) - x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) - x = x + inputs - x = with_sharding_constraint(x, ('batch', 'length', 'embed')) - - # MLP block. - y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) - y = with_sharding_constraint(y, ('batch', 'length', 'embed')) - # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = layers.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - name='mlp', - )(y, deterministic=deterministic) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - y = y + x - y = with_sharding_constraint(y, ('batch', 'length', 'embed')) - - if cfg.scan_layers: - return y, None - else: - return y - - -class DecoderLayer(nn.Module): - """Transformer decoder layer that attends to the encoder.""" - - config: T5Config - - @nn.compact - def __call__( - self, - inputs, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - deterministic=False, - decode=False, - max_decode_length=None, - ): - cfg = self.config - - # Relative position embedding as attention biases. - l = max_decode_length if decode and max_decode_length else inputs.shape[-2] - decoder_bias = layers.RelativePositionBiases( - num_buckets=32, - max_distance=128, - num_heads=cfg.num_heads, - dtype=cfg.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform' - ), - name='relpos_bias', - )(l, l, False) - - inputs = with_sharding_constraint(inputs, ('batch', 'length', 'embed')) - - # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - x = layers.LayerNorm(dtype=cfg.dtype, name='pre_self_attention_layer_norm')( - inputs - ) - x = with_sharding_constraint(x, ('batch', 'length', 'embed')) - - # Self-attention block - x = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - name='self_attention', - )( - x, - x, - decoder_mask, - decoder_bias, - deterministic=deterministic, - decode=decode, - ) - x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) - x = x + inputs - x = with_sharding_constraint(x, ('batch', 'length', 'embed')) - - # Encoder-Decoder block. - y = layers.LayerNorm( - dtype=cfg.dtype, name='pre_cross_attention_layer_norm' - )(x) - y = with_sharding_constraint(y, ('batch', 'length', 'embed')) - y = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - name='encoder_decoder_attention', - )(y, encoded, encoder_decoder_mask, deterministic=deterministic) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - y = y + x - y = with_sharding_constraint(y, ('batch', 'length', 'embed')) - - # MLP block. - z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y) - z = with_sharding_constraint(z, ('batch', 'length', 'embed')) - z = layers.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - name='mlp', - )(z, deterministic=deterministic) - z = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - z, deterministic=deterministic - ) - z = z + y - z = with_sharding_constraint(z, ('batch', 'length', 'embed')) - - if cfg.scan_layers: - return z, None - else: - return z - - -class Encoder(nn.Module): - """A stack of encoder layers.""" - - config: T5Config - shared_embedding: nn.Module - - @nn.compact - def __call__( - self, encoder_input_tokens, encoder_mask=None, deterministic=False - ): - cfg = self.config - assert encoder_input_tokens.ndim == 2 # [batch, length] - - # [batch, length] -> [batch, length, emb_dim] - x = self.shared_embedding(encoder_input_tokens.astype('int32')) - x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) - x = x.astype(cfg.dtype) - - BlockLayer = EncoderLayer - - if cfg.remat_policy not in (None, 'none'): - if cfg.remat_policy == 'minimal': - policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - else: - policy = None - BlockLayer = remat( # pylint: disable=invalid-name - BlockLayer, - prevent_cse=not cfg.scan_layers, - policy=policy, - static_argnums=(2,), - ) - - if cfg.scan_layers: - initializing = self.is_mutable_collection('params') - params_spec = ( - cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) - ) - cache_spec = 0 - x, _ = scan_with_axes( - BlockLayer, - variable_axes={ - 'params': params_spec, - 'cache': cache_spec, - }, - split_rngs={'params': True, 'dropout': True}, - in_axes=(nn.broadcast, nn.broadcast), - length=cfg.num_encoder_layers, - axis_name='layers', - )(config=cfg, name='encoder')(x, encoder_mask, deterministic) - else: - for lyr in range(cfg.num_encoder_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = BlockLayer(config=cfg, name=f'layers_{lyr}')( - x, encoder_mask, deterministic - ) - - x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) - return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) - - -class Decoder(nn.Module): - """A stack of decoder layers as a part of an encoder-decoder architecture.""" - - config: T5Config - shared_embedding: nn.Module - - @nn.compact - def __call__( - self, - encoded, - decoder_input_tokens, - decoder_positions=None, - decoder_mask=None, - encoder_decoder_mask=None, - deterministic=False, - decode=False, - max_decode_length=None, - ): - cfg = self.config - assert decoder_input_tokens.ndim == 2 # [batch, len] - - # [batch, length] -> [batch, length, emb_dim] - y = self.shared_embedding(decoder_input_tokens.astype('int32')) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - y = y.astype(cfg.dtype) - - BlockLayer = DecoderLayer - - if cfg.remat_policy not in (None, 'none'): - if cfg.remat_policy == 'minimal': - policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - else: - policy = None - BlockLayer = remat( # pylint: disable=invalid-name - BlockLayer, - prevent_cse=not cfg.scan_layers, - policy=policy, - static_argnums=(4, 5, 6), - ) - if cfg.scan_layers: - initializing = self.is_mutable_collection('params') - params_spec = ( - cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) - ) - cache_spec = 0 - y, _ = scan_with_axes( - BlockLayer, - variable_axes={'params': params_spec, 'cache': cache_spec}, - split_rngs={'params': True, 'dropout': True}, - in_axes=( - nn.broadcast, - nn.broadcast, - nn.broadcast, - nn.broadcast, - nn.broadcast, - nn.broadcast, - ), - length=cfg.num_decoder_layers, - axis_name='layers', - )(config=cfg, name='decoder')( - y, - encoded, - decoder_mask, - encoder_decoder_mask, - deterministic, - decode, - max_decode_length, - ) - else: - for lyr in range(cfg.num_decoder_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = BlockLayer(config=cfg, name=f'layers_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - deterministic=deterministic, - decode=decode, - max_decode_length=max_decode_length, - ) - - y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - - # [batch, length, emb_dim] -> [batch, length, vocab_size] - if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for logit transform. - logits = self.shared_embedding.attend(y) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - else: - logits = layers.DenseGeneral( - cfg.vocab_size, - dtype=jnp.float32, # Use float32 for stabiliity. - kernel_axes=('embed', 'vocab'), - name='logits_dense', - )(y) - return logits - - -class Transformer(nn.Module): - """An encoder-decoder Transformer model.""" - - config: T5Config - # needed only for janky models.py scan_layers detection. - scan_layers: bool = struct.field(init=False) - - def __post_init__(self): - super().__post_init__() - # needed only for janky models.py scan_layers detection. - object.__setattr__( - self, 'scan_layers', object.__getattribute__(self, 'config').scan_layers - ) - - def setup(self): - cfg = self.config - self.shared_embedding = layers.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - attend_dtype=jnp.float32, # for logit training stability - embedding_init=nn.initializers.normal(stddev=1.0), - one_hot=True, - name='token_embedder', - ) - - self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) - self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - - def encode( - self, encoder_input_tokens, encoder_segment_ids=None, enable_dropout=True - ): - """Applies Transformer encoder-branch on the inputs.""" - cfg = self.config - assert encoder_input_tokens.ndim == 2 # (batch, len) - - # Make padding attention mask. - encoder_mask = layers.make_attention_mask( - encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype - ) - # Add segmentation block-diagonal attention mask if using segmented data. - if encoder_segment_ids is not None: - encoder_mask = layers.combine_masks( - encoder_mask, - layers.make_attention_mask( - encoder_segment_ids, - encoder_segment_ids, - jnp.equal, - dtype=cfg.dtype, - ), - ) - - return self.encoder( - encoder_input_tokens, encoder_mask, deterministic=not enable_dropout - ) - - def decode( - self, - encoded, - encoder_input_tokens, # only needed for masks - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - decoder_positions=None, - enable_dropout=True, - decode=False, - max_decode_length=None, - ): - """Applies Transformer decoder-branch on encoded-input and target.""" - cfg = self.config - - # Make padding attention masks. - if decode: - # Do not mask decoder attention based on targets padding at - # decoding/inference time. - decoder_mask = None - encoder_decoder_mask = layers.make_attention_mask( - jnp.ones_like(decoder_target_tokens), - encoder_input_tokens > 0, - dtype=cfg.dtype, - ) - else: - decoder_mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=cfg.dtype, - decoder_segment_ids=decoder_segment_ids, - ) - encoder_decoder_mask = layers.make_attention_mask( - decoder_target_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype - ) - - # Add segmentation block-diagonal attention masks if using segmented data. - if encoder_segment_ids is not None: - if decode: - raise ValueError( - 'During decoding, packing should not be used but ' - '`encoder_segment_ids` was passed to `Transformer.decode`.' - ) - - encoder_decoder_mask = layers.combine_masks( - encoder_decoder_mask, - layers.make_attention_mask( - decoder_segment_ids, - encoder_segment_ids, - jnp.equal, - dtype=cfg.dtype, - ), - ) - - logits = self.decoder( - encoded, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - deterministic=not enable_dropout, - decode=decode, - max_decode_length=max_decode_length, - ) - return logits - - def __call__( - self, - encoder_input_tokens, - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - encoder_positions=None, - decoder_positions=None, - *, - enable_dropout: bool = True, - decode: bool = False, - ): - """Applies Transformer model on the inputs. - - This method requires both decoder_target_tokens and decoder_input_tokens, - which is a shifted version of the former. For a packed dataset, it usually - has additional processing applied. For example, the first element of each - sequence has id 0 instead of the shifted EOS id from the previous sequence. - - Args: - encoder_input_tokens: input data to the encoder. - decoder_input_tokens: input token to the decoder. - decoder_target_tokens: target token to the decoder. - encoder_segment_ids: encoder segmentation info for packed examples. - decoder_segment_ids: decoder segmentation info for packed examples. - encoder_positions: encoder subsequence positions for packed examples. - decoder_positions: decoder subsequence positions for packed examples. - enable_dropout: Ensables dropout if set to True. - decode: Whether to prepare and use an autoregressive cache. - - Returns: - logits array from full transformer. - """ - encoded = self.encode( - encoder_input_tokens, - encoder_segment_ids=encoder_segment_ids, - enable_dropout=enable_dropout, - ) - - return self.decode( - encoded, - encoder_input_tokens, # only used for masks - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=encoder_segment_ids, - decoder_segment_ids=decoder_segment_ids, - decoder_positions=decoder_positions, - enable_dropout=enable_dropout, - decode=decode, - ) diff --git a/t5x-main/t5x/examples/scalable_t5/network_test.py b/t5x-main/t5x/examples/scalable_t5/network_test.py deleted file mode 100644 index a4a584872d69882f09acce9c390f921d6e0ab76d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/network_test.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for network.""" - -import os - -from absl import flags -from absl.testing import absltest -from absl.testing import parameterized -import jax -import numpy as np -import seqio -from t5x import adafactor -from t5x import models -from t5x import test_utils -from t5x.examples.scalable_t5 import network - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -FLAGS = flags.FLAGS - - -def get_test_model( - emb_dim, - head_dim, - num_heads, - mlp_dim, - dtype='float32', - vocab_size=32128, - num_encoder_layers=2, - num_decoder_layers=2, -): - config = network.T5Config( - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - vocab_size=vocab_size, - dropout_rate=0, - emb_dim=emb_dim, - num_heads=num_heads, - head_dim=head_dim, - mlp_dim=mlp_dim, - dtype=dtype, - mlp_activations=('gelu', 'linear'), - ) - module = network.Transformer(config=config) - vocab = seqio.test_utils.sentencepiece_vocab() - optimizer_def = adafactor.Adafactor() - return models.EncoderDecoderModel( - module, vocab, vocab, optimizer_def=optimizer_def - ) - - -class NetworkTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - batch_size, max_decode_len, input_len = 2, 3, 4 - self.input_shapes = { - 'encoder_input_tokens': (batch_size, input_len), - 'decoder_input_tokens': (batch_size, max_decode_len), - } - np.random.seed(42) - self.batch = { - 'encoder_input_tokens': np.random.randint( - 3, 10, size=(batch_size, input_len) - ), - 'decoder_input_tokens': np.random.randint( - 3, 10, size=(batch_size, max_decode_len) - ), - 'decoder_target_tokens': np.random.randint( - 3, 10, size=(batch_size, max_decode_len) - ), - } - - def test_regression(self): - model = get_test_model( - emb_dim=13, - head_dim=64, - num_heads=8, - mlp_dim=2048, - vocab_size=10, - num_encoder_layers=3, - ) - params = model.get_initial_variables( - jax.random.PRNGKey(0), self.input_shapes - )['params'] - loss, _ = model.loss_fn(params, self.batch, jax.random.PRNGKey(1)) # pylint: disable=unused-variable - - # self.assertAlmostEqual(loss, 16.45335, delta=0.05) - # predicted, scores = model.predict_batch_with_aux(params, self.batch) - # np.testing.assert_array_equal(predicted, [[7, 1, 0], [7, 1, 0]]) - # np.testing.assert_allclose( - # scores['scores'], [-1.240393, -2.035653], rtol=1e-2 - # ) - - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/__init__.py b/t5x-main/t5x/examples/scalable_t5/t5_1_1/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/base.gin b/t5x-main/t5x/examples/scalable_t5/t5_1_1/base.gin deleted file mode 100644 index ebab93a6792375c3a58daff9cb0a27deff4ea1bb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/base.gin +++ /dev/null @@ -1,57 +0,0 @@ -# T5.1.1 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.examples.scalable_t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 2048 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False - scan_layers = True - remat_policy = 'minimal' diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/examples/__init__.py b/t5x-main/t5x/examples/scalable_t5/t5_1_1/examples/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/examples/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/examples/wmt19_ende_from_scratch.gin b/t5x-main/t5x/examples/scalable_t5/t5_1_1/examples/wmt19_ende_from_scratch.gin deleted file mode 100644 index 1d75be863781c66a324b47c311ac8ac04c205da7..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/examples/wmt19_ende_from_scratch.gin +++ /dev/null @@ -1,62 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import adafactor -from t5x import models -from t5x import partitioning -from t5x import trainer -from t5x import utils -from t5x.examples.scalable_t5 import network - -include "t5x/examples/scalable_t5/t5_1_1/base.gin" -include "t5x/configs/runs/finetune.gin" - -MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" -MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} -TRAIN_STEPS = 25000 -LABEL_SMOOTHING = 0.1 -INITIAL_CHECKPOINT_PATH = None -# Note that `DROPOUT_RATE = 0.1` is specified in the finetune.gin but we just -# repeat to make it explicit. -DROPOUT_RATE = 0.1 - -train/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - use_custom_packing_ops = False - seed = 0 - -train_eval/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = False - use_custom_packing_ops = False - seed = 0 - -infer_eval/utils.DatasetConfig: - use_cached = False - -train_script.train: - eval_period = 250 - eval_steps = 20 - random_seed = 0 - use_hardware_rng = True - -utils.CheckpointConfig.restore = None -utils.SaveCheckpointConfig: - period = 500 # checkpoint frequency - keep = 1 - -# Decoder overrides -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -trainer.Trainer.num_microbatches = 2 -utils.create_learning_rate_scheduler.warmup_steps = 1000 - -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 1, 2) - -adafactor.Adafactor: - logical_factor_rules = @adafactor.standard_logical_factor_rules() diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/large.gin b/t5x-main/t5x/examples/scalable_t5/t5_1_1/large.gin deleted file mode 100644 index b01f319967d7a1c39f58beaa45f012e2b65de9db..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Large model. - -include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 2816 diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/small.gin b/t5x-main/t5x/examples/scalable_t5/t5_1_1/small.gin deleted file mode 100644 index d1a8005c66994f2951c67fa65c91c1a77d86e576..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Small model. - -include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 6 - num_encoder_layers = 8 - num_decoder_layers = 8 - head_dim = 64 - mlp_dim = 1024 diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/xl.gin b/t5x-main/t5x/examples/scalable_t5/t5_1_1/xl.gin deleted file mode 100644 index d8d98b4d55eee083b17852042296233bf8c6bbc5..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/xl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 XL model. - -include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2048 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 5120 diff --git a/t5x-main/t5x/examples/scalable_t5/t5_1_1/xxl.gin b/t5x-main/t5x/examples/scalable_t5/t5_1_1/xxl.gin deleted file mode 100644 index 8ed37fe9209349dca2fe146675d05fb4a1f8eb8e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/t5_1_1/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 XXL model. - -include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4096 - num_heads = 64 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 10240 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/__init__.py b/t5x-main/t5x/examples/scalable_t5/umt5/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/architectures/__init__.py b/t5x-main/t5x/examples/scalable_t5/umt5/architectures/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/architectures/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin b/t5x-main/t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin deleted file mode 100644 index f264408216eed94e43d51908ffc212f8f6c639a1..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin +++ /dev/null @@ -1,4 +0,0 @@ -from t5x.examples.scalable_t5 import network - -# This macro should be set in the vocabulary gin file -network.T5Config.vocab_size = %VOCAB_SIZE diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/__init__.py b/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin b/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin deleted file mode 100644 index 549567f528d8017bcb444f5fbc0dccaa02356228..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin +++ /dev/null @@ -1,3 +0,0 @@ -from t5x import adafactor - -OPTIMIZER = @adafactor.Adafactor() diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/adafactor_momentum_nofactor.gin b/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/adafactor_momentum_nofactor.gin deleted file mode 100644 index 93cd56c6889c3e835d8b2d1455312690349d0a20..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/optimizer/adafactor_momentum_nofactor.gin +++ /dev/null @@ -1,7 +0,0 @@ -from t5x import adafactor - -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - beta1 = 0.9 - factored = False - global_norm_clip_threshold = 1.0 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_base.gin b/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_base.gin deleted file mode 100644 index 2007a6bc728d1ab6dbf593c40254f2c33424989e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_base.gin +++ /dev/null @@ -1,25 +0,0 @@ -# Model (has to be imported first so that optimizer and vocab can be overridden) -include "t5x/examples/scalable_t5/mt5/base.gin" - -# Architecture-specific configs -include "t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin" - -# Run mode -include "t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin" - -# Optimizer -include "t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin" - -# Vocabulary -include "t5x/examples/scalable_t5/umt5/vocab.gin" - -# Partitioning -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 1, 1) # Data-parallel only - -# Task configurations -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TRAIN_EVAL_MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 229} -USE_CACHED_TASKS = True -TRAIN_STEPS = 1_000_000 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_large.gin b/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_large.gin deleted file mode 100644 index 8824f0435477f50a94d9baeab67fb6db1573128c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_large.gin +++ /dev/null @@ -1,25 +0,0 @@ -# Model (has to be imported first so that optimizer and vocab can be overridden) -include "t5x/examples/scalable_t5/mt5/large.gin" - -# Architecture-specific configs -include "t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin" - -# Run mode -include "t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin" - -# Optimizer -include "t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin" - -# Vocabulary -include "t5x/examples/scalable_t5/umt5/vocab.gin" - -# Partitioning -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 1, 1) # Data-parallel only - -# Task configurations -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TRAIN_EVAL_MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 229} -USE_CACHED_TASKS = True -TRAIN_STEPS = 1_000_000 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_small.gin b/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_small.gin deleted file mode 100644 index a98d0a03dd3617683b196d70b6565d38277ee4c3..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_small.gin +++ /dev/null @@ -1,25 +0,0 @@ -# Model (has to be imported first so that optimizer and vocab can be overridden) -include "t5x/examples/scalable_t5/mt5/small.gin" - -# Architecture-specific configs -include "t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin" - -# Run mode -include "t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin" - -# Optimizer -include "t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin" - -# Vocabulary -include "t5x/examples/scalable_t5/umt5/vocab.gin" - -# Partitioning -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 1, 1) # Data-parallel only - -# Task configurations -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TRAIN_EVAL_MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 229} -USE_CACHED_TASKS = True -TRAIN_STEPS = 1_000_000 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_xl.gin b/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_xl.gin deleted file mode 100644 index 4366c66dff839cfc9c971d83fba5f3f66e06ef36..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_xl.gin +++ /dev/null @@ -1,25 +0,0 @@ -# Model (has to be imported first so that optimizer and vocab can be overridden) -include "t5x/examples/scalable_t5/mt5/xl.gin" - -# Architecture-specific configs -include "t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin" - -# Run mode -include "t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin" - -# Optimizer -include "t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin" - -# Vocabulary -include "t5x/examples/scalable_t5/umt5/vocab.gin" - -# Partitioning -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 1, 1) # Data-parallel only - -# Task configurations -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TRAIN_EVAL_MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 229} -USE_CACHED_TASKS = True -TRAIN_STEPS = 1_000_000 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_xxl.gin b/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_xxl.gin deleted file mode 100644 index 14d7edf2f60bbb68343dbe526fbdd7053aec5c85..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/pretrain_xxl.gin +++ /dev/null @@ -1,32 +0,0 @@ -from __gin__ import dynamic_registration - -from t5x import partitioning - -# Model (has to be imported first so that optimizer and vocab can be overridden) -include "t5x/examples/scalable_t5/mt5/xxl.gin" - -# Architecture-specific configs -include "t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin" - -# Run mode -include "t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin" - -# Optimizer -include "t5x/examples/scalable_t5/umt5/optimizer/adafactor_momentum_nofactor.gin" - -# Vocabulary -include "t5x/examples/scalable_t5/umt5/vocab.gin" - -# Partitioning -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 8, 1) - -# Task configurations -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TRAIN_EVAL_MIXTURE_OR_TASK_NAME = %gin.REQUIRED -TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 229} -USE_CACHED_TASKS = True -TRAIN_STEPS = 1_000_000 - -partitioning.standard_logical_axis_rules.activation_partitioning_dims = 1 -partitioning.standard_logical_axis_rules.parameter_partitioning_dims = 2 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/runs/__init__.py b/t5x-main/t5x/examples/scalable_t5/umt5/runs/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/runs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/runs/finetuning_common.gin b/t5x-main/t5x/examples/scalable_t5/umt5/runs/finetuning_common.gin deleted file mode 100644 index 2e01b6a13b278ae4faa04cd10fa84f8651efa729..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/runs/finetuning_common.gin +++ /dev/null @@ -1,31 +0,0 @@ -# Common configurations for fine-tuning runs. -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import utils - -include "t5x/configs/runs/finetune.gin" - -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} -TRAIN_STEPS = %gin.REQUIRED -BATCH_SIZE = 64 -DROPOUT_RATE = 0.1 -RANDOM_SEED = 0 -LOSS_NORMALIZING_FACTOR = "NUM_REAL_TARGET_TOKENS" -USE_CACHED_TASKS = False -LEARNING_RATE = 0.0005 - -train_script.train: - eval_period = 1000 - stats_period = 1000 - eval_steps = 20 - random_seed = 0 - -utils.SaveCheckpointConfig: - period = 1000 - save_dataset = True - keep_dataset_checkpoints = 3 - -train/utils.DatasetConfig.seed = 42 - -utils.create_learning_rate_scheduler.base_learning_rate = %LEARNING_RATE diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin b/t5x-main/t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin deleted file mode 100644 index b64b15e61fd52d95b35af72fb8e146b5f7bdf0cf..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin +++ /dev/null @@ -1,39 +0,0 @@ -# Common configurations for pretraining runs. -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import utils - -include 't5x/configs/runs/pretrain_deterministic.gin' - -TASK_FEATURE_LENGTHS = None -TRAIN_STEPS = 1_000_000 -BATCH_SIZE = 1024 -DROPOUT_RATE = 0.0 -RANDOM_SEED = 0 -LOSS_NORMALIZING_FACTOR = "NUM_REAL_TARGET_TOKENS" -USE_CACHED_TASKS_TRAIN_EVAL = False - -train/utils.DatasetConfig: - pack = False - use_cached = True - -train_eval/utils.DatasetConfig: - mixture_or_task_name = %TRAIN_EVAL_MIXTURE_OR_TASK_NAME - pack = False - use_cached = %USE_CACHED_TASKS_TRAIN_EVAL - -train_script.train: - eval_period = 2000 - stats_period = 500 - eval_steps = 20 - random_seed = 0 - train_eval_get_dataset_fn = @utils.get_training_eval_datasets - -utils.get_training_eval_datasets: - deterministic = False - -utils.SaveCheckpointConfig: - period = 2000 - -train/utils.DatasetConfig.seed = 42 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/tydiqa_base.gin b/t5x-main/t5x/examples/scalable_t5/umt5/tydiqa_base.gin deleted file mode 100644 index dbc4d0447b9580e3c161d203689ec90293499d8e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/tydiqa_base.gin +++ /dev/null @@ -1,33 +0,0 @@ -# Runs TyDi QA fine-tuning -from __gin__ import dynamic_registration -from t5x import adafactor - -# Model (has to be imported first so that optimizer and vocab can be overridden) -include "t5x/examples/scalable_t5/mt5/base.gin" - -# Architecture-specific configs -include "t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin" - -# Run mode -include "t5x/examples/scalable_t5/umt5/runs/finetuning_common.gin" - -# Optimizer -include "t5x/examples/scalable_t5/umt5/optimizer/adafactor.gin" - -# Vocabulary -include "t5x/examples/scalable_t5/umt5/vocab.gin" - -# Partitioning -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 2, 1, 1) - -INITIAL_CHECKPOINT_PATH = "gs://t5-data/t5-data/pretrained_models/t5x/umt5_base/checkpoint_1000000" - -MIXTURE_OR_TASK_NAME = %gin.REQUIRED -USE_CACHED_TASKS = False -TRAIN_STEPS = 1_050_000 # 1_000_000 pretrained steps + 50_000 fine-tuning -TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 256} -LEARNING_RATE = 0.00005 -BATCH_SIZE = 32 - -adafactor.Adafactor.step_offset = 1_000_000 diff --git a/t5x-main/t5x/examples/scalable_t5/umt5/vocab.gin b/t5x-main/t5x/examples/scalable_t5/umt5/vocab.gin deleted file mode 100644 index 15f2b860548e45f3fdf15c3bec20ecc0fa742bbd..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/scalable_t5/umt5/vocab.gin +++ /dev/null @@ -1,9 +0,0 @@ -VOCABULARY = @seqio.SentencePieceVocabulary() - -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/umt5.256000/sentencepiece.model" - -seqio.SentencePieceVocabulary.extra_ids = 300 - -# This macro should be used by each architecture gin file to configure the -# actual model vocab size (e.g. `network.T5Config.vocab_size = %VOCAB_SIZE`) -VOCAB_SIZE = 256384 diff --git a/t5x-main/t5x/examples/t5/README.md b/t5x-main/t5x/examples/t5/README.md deleted file mode 100644 index bcabd31410b413909d05e5f0d8bd5f26d020e29c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/README.md +++ /dev/null @@ -1,6 +0,0 @@ -This directory contains model implementations for the T5-variants (T5.1.1, -T5.1.0, mT5, ByT5). All variants share the neural network implementation in -`network.py`, which has a minimal set of configurables in `TransformerConfig`. - -Refer to the [main -README](https://github.com/google-research/t5x/blob/main/README.md) for the example usages. diff --git a/t5x-main/t5x/examples/t5/__init__.py b/t5x-main/t5x/examples/t5/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/t5/byt5/__init__.py b/t5x-main/t5x/examples/t5/byt5/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/byt5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/t5/byt5/base.gin b/t5x-main/t5x/examples/t5/byt5/base.gin deleted file mode 100644 index 4e2122392c0c0aa57682bae062b1395ef451349d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/byt5/base.gin +++ /dev/null @@ -1,54 +0,0 @@ -# ByT5 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.examples.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.ByteVocabulary() - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 384 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 1536 - num_heads = 12 - num_encoder_layers = 18 - num_decoder_layers = 6 - head_dim = 64 - mlp_dim = 3968 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False diff --git a/t5x-main/t5x/examples/t5/byt5/large.gin b/t5x-main/t5x/examples/t5/byt5/large.gin deleted file mode 100644 index d4b8aaa3b42103877eb0bdb0fdb97d1b87c6f47a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/byt5/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 Large model. - -include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1536 - num_heads = 16 - num_encoder_layers = 36 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 3840 diff --git a/t5x-main/t5x/examples/t5/byt5/small.gin b/t5x-main/t5x/examples/t5/byt5/small.gin deleted file mode 100644 index 11eeff1ab8a9caeca663432651ff2fdb9c8de7a4..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/byt5/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 Small model. - -include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1472 - num_heads = 6 - num_encoder_layers = 12 - num_decoder_layers = 4 - head_dim = 64 - mlp_dim = 3584 diff --git a/t5x-main/t5x/examples/t5/byt5/tiny.gin b/t5x-main/t5x/examples/t5/byt5/tiny.gin deleted file mode 100644 index ed83eecd0b229ffd8b50561241e268d9cfc3ecfb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/byt5/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/examples/t5/byt5/xl.gin b/t5x-main/t5x/examples/t5/byt5/xl.gin deleted file mode 100644 index cbf38aaf51f525f0ac7e3870902fd43ed95a2574..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/byt5/xl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 XL model. - -include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2560 - num_heads = 32 - num_encoder_layers = 36 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 6720 diff --git a/t5x-main/t5x/examples/t5/byt5/xxl.gin b/t5x-main/t5x/examples/t5/byt5/xxl.gin deleted file mode 100644 index 24fa418f6664c84d3c27e376b68b384d9baace90..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/byt5/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# ByT5 XXL model. - -include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4672 - num_heads = 64 - num_encoder_layers = 36 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 12352 diff --git a/t5x-main/t5x/examples/t5/layers.py b/t5x-main/t5x/examples/t5/layers.py deleted file mode 100644 index cc2e6ac8b82a93218fdb488d57cbdfb75c36a283..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/layers.py +++ /dev/null @@ -1,924 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dense attention classes and mask/weighting functions.""" - -# pylint: disable=attribute-defined-outside-init,g-bare-generic - -import dataclasses -import functools -import operator -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union - -from flax import linen as nn -from flax.linen import partitioning as nn_partitioning -import jax -from jax import lax -from jax import random -import jax.numpy as jnp -import numpy as np - - -# from flax.linen.partitioning import param_with_axes, with_sharding_constraint -param_with_axes = nn_partitioning.param_with_axes -with_sharding_constraint = nn_partitioning.with_sharding_constraint - - -# Type annotations -Array = jnp.ndarray -DType = jnp.dtype -PRNGKey = jnp.ndarray -Shape = Sequence[int] -Activation = Callable[..., Array] -# Parameter initializers. -Initializer = Callable[[PRNGKey, Shape, DType], Array] - -default_embed_init = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 -) - - -def dot_product_attention( - query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0.0, - deterministic: bool = False, - dtype: DType = jnp.float32, - float32_logits: bool = False, -): - """Computes dot-product attention given query, key, and value. - - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It calculates the attention weights given - query and key and combines the values using the attention weights. - - Args: - query: queries for calculating attention with shape of `[batch, q_length, - num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of `[batch, kv_length, - num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of `[batch, kv_length, - num_heads, v_depth_per_head]`. - bias: bias for the attention weights. This should be broadcastable to the - shape `[batch, num_heads, q_length, kv_length]` This can be used for - incorporating causal masks, padding masks, proximity bias, etc. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - dtype: the dtype of the computation (default: float32) - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - - Returns: - Output of shape `[batch, length, num_heads, v_depth_per_head]`. - """ - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert ( - query.shape[:-3] == key.shape[:-3] == value.shape[:-3] - ), 'q, k, v batch dims must match.' - assert ( - query.shape[-2] == key.shape[-2] == value.shape[-2] - ), 'q, k, v num_heads must match.' - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' - - # Casting logits and softmax computation for float32 for model stability. - if float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) - - # `attn_weights`: [batch, num_heads, q_length, kv_length] - attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) - - # Apply attention bias: masking, dropout, proximity bias, etc. - if bias is not None: - attn_weights = attn_weights + bias.astype(attn_weights.dtype) - - # Normalize the attention weights across `kv_length` dimension. - attn_weights = jax.nn.softmax(attn_weights).astype(dtype) - - # Apply attention dropout. - if not deterministic and dropout_rate > 0.0: - keep_prob = 1.0 - dropout_rate - # T5 broadcasts along the "length" dim, but unclear which one that - # corresponds to in positional dimensions here, assuming query dim. - dropout_shape = list(attn_weights.shape) - dropout_shape[-2] = 1 - keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) - keep = jnp.broadcast_to(keep, attn_weights.shape) - multiplier = keep.astype(attn_weights.dtype) / jnp.asarray( - keep_prob, dtype=dtype - ) - attn_weights = attn_weights * multiplier - - # Take the linear combination of `value`. - return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) - - -dynamic_vector_slice_in_dim = jax.vmap( - lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None) -) - - -class MultiHeadDotProductAttention(nn.Module): - """Multi-head dot-product attention. - - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - head_dim: dimension of each head. - dtype: the dtype of the computation. - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. - """ - - num_heads: int - head_dim: int - dtype: DType = jnp.float32 - dropout_rate: float = 0.0 - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal' - ) - float32_logits: bool = False # computes logits in float32 for stability. - - @nn.compact - def __call__( - self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - *, - decode: bool = False, - deterministic: bool = False, - ) -> Array: - """Applies multi-head dot product attention on the input data. - - Projects the inputs into multi-headed query, key, and value vectors, - applies dot-product attention and project the results to an output vector. - - There are two modes: decoding and non-decoding (e.g., training). The mode is - determined by `decode` argument. For decoding, this method is called twice, - first to initialize the cache and then for an actual decoding process. The - two calls are differentiated by the presence of 'cached_key' in the variable - dict. In the cache initialization stage, the cache variables are initialized - as zeros and will be filled in the subsequent decoding process. - - In the cache initialization call, `inputs_q` has a shape [batch, length, - q_features] and `inputs_kv`: [batch, length, kv_features]. During the - incremental decoding stage, query, key and value all have the shape [batch, - 1, qkv_features] corresponding to a single step. - - Args: - inputs_q: input queries of shape `[batch, q_length, q_features]`. - inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. - mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. - bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. - decode: Whether to prepare and use an autoregressive cache. - deterministic: Disables dropout if set to True. - - Returns: - output of shape `[batch, length, q_features]`. - """ - projection = functools.partial( - DenseGeneral, - axis=-1, - features=(self.num_heads, self.head_dim), - kernel_axes=('embed', 'joined_kv'), - dtype=self.dtype, - ) - - # NOTE: T5 does not explicitly rescale the attention logits by - # 1/sqrt(depth_kq)! This is folded into the initializers of the - # linear transformations, which is equivalent under Adafactor. - depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - query_init = lambda *args: self.kernel_init(*args) / depth_scaling - - # Project inputs_q to multi-headed q/k/v - # dimensions are then [batch, length, num_heads, head_dim] - query = projection(kernel_init=query_init, name='query')(inputs_q) - key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) - value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) - - query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) - key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) - value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) - - if decode: - # Detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable('cache', 'cached_key') - # The key and value have dimension [batch, length, num_heads, head_dim], - # but we cache them as [batch, num_heads, head_dim, length] as a TPU - # fusion optimization. This also enables the "scatter via one-hot - # broadcast" trick, which means we do a one-hot broadcast instead of a - # scatter/gather operations, resulting in a 3-4x speedup in practice. - swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) - cached_key = self.variable( - 'cache', 'cached_key', jnp.zeros, swap_dims(key.shape), key.dtype - ) - cached_value = self.variable( - 'cache', - 'cached_value', - jnp.zeros, - swap_dims(value.shape), - value.dtype, - ) - cache_index = self.variable( - 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) - ) - if is_initialized: - batch, num_heads, head_dim, length = cached_key.value.shape - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - # Sanity shape check of cached key against input query. - expected_shape = (batch, 1, num_heads, head_dim) - if expected_shape != query.shape: - raise ValueError( - 'Autoregressive cache shape error, ' - 'expected query shape %s instead got %s.' - % (expected_shape, query.shape) - ) - - # Create a OHE of the current index. NOTE: the index is increased below. - cur_index = cache_index.value - one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) - # In order to update the key, value caches with the current key and - # value, we move the length axis to the back, similar to what we did for - # the cached ones above. - # Note these are currently the key and value of a single position, since - # we feed one position at a time. - one_token_key = jnp.moveaxis(key, -3, -1) - one_token_value = jnp.moveaxis(value, -3, -1) - # Update key, value caches with our new 1d spatial slices. - # We implement an efficient scatter into the cache via one-hot - # broadcast and addition. - key = cached_key.value + one_token_key * one_hot_indices - value = cached_value.value + one_token_value * one_hot_indices - cached_key.value = key - cached_value.value = value - cache_index.value = cache_index.value + 1 - # Move the keys and values back to their original shapes. - key = jnp.moveaxis(key, -1, -3) - value = jnp.moveaxis(value, -1, -3) - - # Causal mask for cached decoder self-attention: our single query - # position should only attend to those key positions that have already - # been generated and cached, not the remaining zero elements. - mask = combine_masks( - mask, - jnp.broadcast_to( - jnp.arange(length) <= cur_index, - # (1, 1, length) represent (head dim, query length, key length) - # query length is 1 because during decoding we deal with one - # index. - # The same mask is applied to all batch elements and heads. - (batch, 1, 1, length), - ), - ) - - # Grab the correct relative attention bias during decoding. This is - # only required during single step decoding. - if bias is not None: - # The bias is a full attention matrix, but during decoding we only - # have to take a slice of it. - # This is equivalent to bias[..., cur_index:cur_index+1, :]. - bias = dynamic_vector_slice_in_dim( - jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2 - ) - - # Convert the boolean attention mask to an attention bias. - if mask is not None: - # attention mask in the form of attention bias - attention_bias = lax.select( - mask > 0, - jnp.full(mask.shape, 0.0).astype(self.dtype), - jnp.full(mask.shape, -1e10).astype(self.dtype), - ) - else: - attention_bias = None - - # Add provided bias term (e.g. relative position embedding). - if bias is not None: - attention_bias = combine_biases(attention_bias, bias) - - dropout_rng = None - if not deterministic and self.dropout_rate > 0.0: - dropout_rng = self.make_rng('dropout') - - # Apply attention. - x = dot_product_attention( - query, - key, - value, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - dtype=self.dtype, - float32_logits=self.float32_logits, - ) - - # Back to the original inputs dimensions. - out = DenseGeneral( - features=inputs_q.shape[-1], # output dim is set to the input dim. - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('joined_kv', 'embed'), - dtype=self.dtype, - name='out', - )(x) - return out - - -def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) - - -def _canonicalize_tuple(x): - if isinstance(x, Iterable): - return tuple(x) - else: - return (x,) - - -# ------------------------------------------------------------------------------ -# DenseGeneral for attention layers. -# ------------------------------------------------------------------------------ -class DenseGeneral(nn.Module): - """A linear transformation (without bias) with flexible axes. - - Attributes: - features: tuple with numbers of output features. - axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - """ - - features: Union[Iterable[int], int] - axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'truncated_normal' - ) - kernel_axes: Tuple[str, ...] = () - - @nn.compact - def __call__(self, inputs: Array) -> Array: - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - - Returns: - The transformed input. - """ - features = _canonicalize_tuple(self.features) - axis = _canonicalize_tuple(self.axis) - - inputs = jnp.asarray(inputs, self.dtype) - axis = _normalize_axes(axis, inputs.ndim) - - kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features - kernel_param_shape = ( - np.prod([inputs.shape[ax] for ax in axis]), - np.prod(features), - ) - kernel = param_with_axes( - 'kernel', - self.kernel_init, - kernel_param_shape, - jnp.float32, - axes=self.kernel_axes, - ) - kernel = jnp.asarray(kernel, self.dtype) - kernel = jnp.reshape(kernel, kernel_shape) - - contract_ind = tuple(range(0, len(axis))) - return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) - - -def _convert_to_activation_function( - fn_or_string: Union[str, Callable] -) -> Callable: - """Convert a string to an activation function.""" - if fn_or_string == 'linear': - return lambda x: x - elif isinstance(fn_or_string, str): - return getattr(nn, fn_or_string) - elif callable(fn_or_string): - return fn_or_string - else: - raise ValueError( - "don't know how to convert %s to an activation function" - % (fn_or_string,) - ) - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. - - Attributes: - intermediate_dim: Shared dimension of hidden layers. - activations: Type of activations for each layer. Each element is either - 'linear', a string function name in flax.linen, or a function. - kernel_init: Kernel function, passed to the dense layers. - deterministic: Whether the dropout layers should be deterministic. - intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. - """ - - intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable]] = ('relu',) - kernel_init: Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'truncated_normal' - ) - intermediate_dropout_rate: float = 0.1 - dtype: Any = jnp.float32 - - @nn.compact - def __call__(self, inputs, decode: bool = False, deterministic: bool = False): - """Applies Transformer MlpBlock module.""" - # Iterate over specified MLP input activation functions. - # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. - activations = [] - for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' - x = DenseGeneral( - self.intermediate_dim, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), - name=dense_name, - )(inputs) - x = _convert_to_activation_function(act_fn)(x) - activations.append(x) - - # Take elementwise product of above intermediate activations. - x = functools.reduce(operator.mul, activations) - # Apply dropout and final dense output projection. - x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) # Broadcast along length. - x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) - output = DenseGeneral( - inputs.shape[-1], - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - name='wo', - )(x) - return output - - -class Embed(nn.Module): - """A parameterized function from integers [0, n) to d-dimensional vectors. - - Attributes: - num_embeddings: number of embeddings. - features: number of feature dimensions for each embedding. - dtype: the dtype of the embedding vectors (default: float32). - embedding_init: embedding initializer. - one_hot: performs the gather with a one-hot contraction rather than a true - gather. This is currently needed for SPMD partitioning. - """ - - num_embeddings: int - features: int - cast_input_dtype: Optional[DType] = None - dtype: DType = jnp.float32 - attend_dtype: Optional[DType] = None - embedding_init: Initializer = default_embed_init - one_hot: bool = False - embedding: Array = dataclasses.field(init=False) - - def setup(self): - self.embedding = param_with_axes( - 'embedding', - self.embedding_init, - (self.num_embeddings, self.features), - jnp.float32, - axes=('vocab', 'embed'), - ) - - def __call__(self, inputs: Array) -> Array: - """Embeds the inputs along the last dimension. - - Args: - inputs: input data, all dimensions are considered batch dimensions. - - Returns: - Output which is embedded input data. The output shape follows the input, - with an additional `features` dimension appended. - """ - if self.cast_input_dtype: - inputs = inputs.astype(self.cast_input_dtype) - if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') - if self.one_hot: - iota = lax.iota(jnp.int32, self.num_embeddings) - one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) - output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) - else: - output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = with_sharding_constraint(output, ('batch', 'length', 'embed')) - return output - - def attend(self, query: Array) -> Array: - """Attend over the embedding using a query array. - - Args: - query: array with last dimension equal the feature depth `features` of the - embedding. - - Returns: - An array with final dim `num_embeddings` corresponding to the batched - inner-product of the array of query vectors against each embedding. - Commonly used for weight-sharing between embeddings and logit transform - in NLP models. - """ - dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype - return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) - - -class RelativePositionBiases(nn.Module): - """Adds T5-style relative positional embeddings to the attention logits. - - Attributes: - num_buckets: Number of buckets to bucket distances between key and query - positions into. - max_distance: Maximum distance before everything is lumped into the last - distance bucket. - num_heads: Number of heads in the attention layer. Each head will get a - different relative position weighting. - dtype: Type of arrays through this module. - embedding_init: initializer for relative embedding table. - """ - - num_buckets: int - max_distance: int - num_heads: int - dtype: Any - embedding_init: Callable[..., Array] = nn.linear.default_embed_init - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """Translate relative position to a bucket number for relative attention. - - The relative position is defined as memory_position - query_position, i.e. - the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are - invalid. - We use smaller buckets for small absolute relative_position and larger - buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative - positions <=-max_distance map to the same bucket. This should allow for - more graceful generalization to longer sequences than the model has been - trained on. - - Args: - relative_position: an int32 array - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 - values in the range [0, num_buckets) - """ - ret = 0 - n = -relative_position - if bidirectional: - num_buckets //= 2 - ret += (n < 0).astype(np.int32) * num_buckets - n = np.abs(n) - else: - n = np.maximum(n, 0) - # now n is in the range [0, inf) - max_exact = num_buckets // 2 - is_small = n < max_exact - val_if_large = max_exact + ( - np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) - / np.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).astype(np.int32) - val_if_large = np.minimum(val_if_large, num_buckets - 1) - ret += np.where(is_small, n, val_if_large) - return ret - - @nn.compact - def __call__(self, qlen, klen, bidirectional=True): - """Produce relative position embedding attention biases. - - Args: - qlen: attention query length. - klen: attention key length. - bidirectional: whether to allow positive memory-query relative position - embeddings. - - Returns: - output: `(1, len, q_len, k_len)` attention bias - """ - # TODO(levskaya): should we be computing this w. numpy as a program - # constant? - context_position = np.arange(qlen, dtype=jnp.int32)[:, None] - memory_position = np.arange(klen, dtype=jnp.int32)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - rp_bucket = self._relative_position_bucket( - relative_position, - bidirectional=bidirectional, - num_buckets=self.num_buckets, - max_distance=self.max_distance, - ) - relative_attention_bias = param_with_axes( - 'rel_embedding', - self.embedding_init, - (self.num_heads, self.num_buckets), - jnp.float32, - axes=('heads', 'relpos_buckets'), - ) - - relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) - # Instead of using a slow gather, we create a leading-dimension one-hot - # array from rp_bucket and use it to perform the gather-equivalent via a - # contraction, i.e.: - # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). - # This is equivalent to relative_attention_bias[:, rp_bucket] - bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) - rp_bucket_one_hot = jnp.array( - rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype - ) - # --> shape (qlen, klen, num_heads) - values = lax.dot_general( - relative_attention_bias, - rp_bucket_one_hot, - (((1,), (0,)), ((), ())), # rhs, lhs contracting dims - ) # no batched dims - # Add a singleton batch dimension. - # --> shape (1, num_heads, qlen, klen) - return values[jnp.newaxis, ...] - - -# ------------------------------------------------------------------------------ -# T5 Layernorm - no subtraction of mean or bias. -# ------------------------------------------------------------------------------ -class LayerNorm(nn.Module): - """T5 Layer normalization operating on the last axis of the input data.""" - - epsilon: float = 1e-6 - dtype: Any = jnp.float32 - scale_init: Initializer = nn.initializers.ones - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - """Applies layer normalization on the input.""" - x = jnp.asarray(x, jnp.float32) - features = x.shape[-1] - mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) - scale = param_with_axes( - 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',) - ) - - scale = jnp.asarray(scale, self.dtype) - return y * scale - - -# ------------------------------------------------------------------------------ -# Mask-making utility functions. -# ------------------------------------------------------------------------------ -def make_attention_mask( - query_input: Array, - key_input: Array, - pairwise_fn: Callable = jnp.multiply, - extra_batch_dims: int = 0, - dtype: DType = jnp.float32, -) -> Array: - """Mask-making helper for attention weights. - - In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the - attention weights will be `[batch, heads, len_q, len_kv]` and this - function will produce `[batch, 1, len_q, len_kv]`. - - Args: - query_input: a batched, flat input of query_length size - key_input: a batched, flat input of key_length size - pairwise_fn: broadcasting elementwise comparison function - extra_batch_dims: number of extra batch dims to add singleton axes for, none - by default - dtype: mask return dtype - - Returns: - A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. - """ - # [batch, len_q, len_kv] - mask = pairwise_fn( - # [batch, len_q] -> [batch, len_q, 1] - jnp.expand_dims(query_input, axis=-1), - # [batch, len_q] -> [batch, 1, len_kv] - jnp.expand_dims(key_input, axis=-2), - ) - - # [batch, 1, len_q, len_kv]. This creates the head dim. - mask = jnp.expand_dims(mask, axis=-3) - mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) - return mask.astype(dtype) - - -def make_causal_mask( - x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32 -) -> Array: - """Make a causal mask for self-attention. - - In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights - will be `[batch, heads, len, len]` and this function will produce a - causal mask of shape `[batch, 1, len, len]`. - - Note that a causal mask does not depend on the values of x; it only depends on - the shape. If x has padding elements, they will not be treated in a special - manner. - - Args: - x: input array of shape `[batch, len]` - extra_batch_dims: number of batch dims to add singleton axes for, none by - default - dtype: mask return dtype - - Returns: - A `[batch, 1, len, len]` shaped causal mask for 1d attention. - """ - idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) - return make_attention_mask( - idxs, - idxs, - jnp.greater_equal, - extra_batch_dims=extra_batch_dims, - dtype=dtype, - ) - - -def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): - """Combine attention masks. - - Args: - *masks: set of attention mask arguments to combine, some can be None. - dtype: final mask dtype - - Returns: - Combined mask, reduced by logical and, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all( - map(lambda x: x.ndim == masks[0].ndim, masks) - ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}' - mask, *other_masks = masks - for other_mask in other_masks: - mask = jnp.logical_and(mask, other_mask) - return mask.astype(dtype) - - -def combine_biases(*masks: Optional[Array]): - """Combine attention biases. - - Args: - *masks: set of attention bias arguments to combine, some can be None. - - Returns: - Combined mask, reduced by summation, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all( - map(lambda x: x.ndim == masks[0].ndim, masks) - ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}' - mask, *other_masks = masks - for other_mask in other_masks: - mask = mask + other_mask - return mask - - -def make_decoder_mask( - decoder_target_tokens: Array, - dtype: DType, - decoder_causal_attention: Optional[Array] = None, - decoder_segment_ids: Optional[Array] = None, -) -> Array: - """Compute the self-attention mask for a decoder. - - Decoder mask is formed by combining a causal mask, a padding mask and an - optional packing mask. If decoder_causal_attention is passed, it makes the - masking non-causal for positions that have value of 1. - - A prefix LM is applied to a dataset which has a notion of "inputs" and - "targets", e.g., a machine translation task. The inputs and targets are - concatenated to form a new target. `decoder_target_tokens` is the concatenated - decoder output tokens. - - The "inputs" portion of the concatenated sequence can attend to other "inputs" - tokens even for those at a later time steps. In order to control this - behavior, `decoder_causal_attention` is necessary. This is a binary mask with - a value of 1 indicating that the position belonged to "inputs" portion of the - original dataset. - - Example: - - Suppose we have a dataset with two examples. - - ds = [{"inputs": [6, 7], "targets": [8]}, - {"inputs": [3, 4], "targets": [5]}] - - After the data preprocessing with packing, the two examples are packed into - one example with the following three fields (some fields are skipped for - simplicity). - - decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] - decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] - decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] - - where each array has [batch, length] shape with batch size being 1. Then, - this function computes the following mask. - - mask = [[[[1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0]]]] - - mask[b, 1, :, :] represents the mask for the example `b` in the batch. - Because mask is for a self-attention layer, the mask's shape is a square of - shape [query length, key length]. - - mask[b, 1, i, j] = 1 means that the query token at position i can attend to - the key token at position j. - - Args: - decoder_target_tokens: decoder output tokens. [batch, length] - dtype: dtype of the output mask. - decoder_causal_attention: a binary mask indicating which position should - only attend to earlier positions in the sequence. Others will attend - bidirectionally. [batch, length] - decoder_segment_ids: decoder segmentation info for packed examples. [batch, - length] - - Returns: - the combined decoder mask. - """ - masks = [] - # The same mask is applied to all attention heads. So the head dimension is 1, - # i.e., the mask will be broadcast along the heads dim. - # [batch, 1, length, length] - causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) - - # Positions with value 1 in `decoder_causal_attneition` can attend - # bidirectionally. - if decoder_causal_attention is not None: - # [batch, 1, length, length] - inputs_mask = make_attention_mask( - decoder_causal_attention, - decoder_causal_attention, - jnp.logical_and, - dtype=dtype, - ) - masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) - else: - masks.append(causal_mask) - - # Padding mask. - masks.append( - make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype - ) - ) - - # Packing mask - if decoder_segment_ids is not None: - masks.append( - make_attention_mask( - decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype - ) - ) - - return combine_masks(*masks, dtype=dtype) # pytype: disable=bad-return-type # jax-ndarray diff --git a/t5x-main/t5x/examples/t5/layers_test.py b/t5x-main/t5x/examples/t5/layers_test.py deleted file mode 100644 index 8d27caf83e1b7588d9e8f6451875845313da8005..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/layers_test.py +++ /dev/null @@ -1,701 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for attention classes.""" - -import dataclasses -from typing import Optional -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -from flax import linen as nn -from flax.core import freeze -from flax.linen import partitioning as nn_partitioning -import jax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp -import numpy as np -from t5x.examples.t5 import layers - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -Array = jnp.ndarray -AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name - - -class SelfAttention(layers.MultiHeadDotProductAttention): - """Self-attention special case of multi-head dot-product attention.""" - - @nn.compact - def __call__( - self, - inputs_q: Array, - mask: Optional[Array] = None, - bias: Optional[Array] = None, - deterministic: bool = False, - ): - return super().__call__( - inputs_q, inputs_q, mask, bias, deterministic=deterministic - ) - - -@dataclasses.dataclass(frozen=True) -class SelfAttentionArgs: - num_heads: int = 1 - batch_size: int = 2 - # qkv_features: int = 3 - head_dim: int = 3 - # out_features: int = 4 - q_len: int = 5 - features: int = 6 - dropout_rate: float = 0.1 - deterministic: bool = False - decode: bool = False - float32_logits: bool = False - - def __post_init__(self): - # If we are doing decoding, the query length should be 1, because are doing - # autoregressive decoding where we feed one position at a time. - assert not self.decode or self.q_len == 1 - - def init_args(self): - return dict( - num_heads=self.num_heads, - head_dim=self.head_dim, - dropout_rate=self.dropout_rate, - float32_logits=self.float32_logits, - ) - - def apply_args(self): - inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) - mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) - return { - 'inputs_q': inputs_q, - 'mask': mask, - 'bias': bias, - 'deterministic': self.deterministic, - } - - -class AttentionTest(parameterized.TestCase): - - def test_dot_product_attention_shape(self): - # This test only checks for shape but tries to make sure all code paths are - # reached. - dropout_rng = random.PRNGKey(0) - batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 - - query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) - key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) - value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) - bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) - - args = dict( - query=query, - key=key, - value=value, - bias=bias, - dropout_rng=dropout_rng, - dropout_rate=0.5, - deterministic=False, - ) - - output = layers.dot_product_attention(**args) - self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) - - def test_make_attention_mask_multiply_pairwise_fn(self): - decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) - attention_mask = layers.make_attention_mask( - decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32 - ) - expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) - expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) - self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_attention_mask_equal_pairwise_fn(self): - segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) - attention_mask = layers.make_attention_mask( - segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32 - ) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected0 = jnp.array([ - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1], - ]) - expected1 = jnp.array([ - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1], - ]) - self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) - np.testing.assert_array_equal(attention_mask[0, 0], expected0) - np.testing.assert_array_equal(attention_mask[1, 0], expected1) - - def test_make_causal_mask_with_padding(self): - x = jnp.array([[7, 0, 0], [8, 5, 0]]) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (2, 1, 3, 3)) - # Padding is not treated in a special way. So they need to be zeroed out - # separately. - expected_y = jnp.array( - [[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]], jnp.float32 - ) - np.testing.assert_allclose(y[0], expected_y) - np.testing.assert_allclose(y[1], expected_y) - - def test_make_causal_mask_extra_batch_dims(self): - x = jnp.ones((3, 3, 5)) - y = layers.make_causal_mask(x, extra_batch_dims=2) - self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) - - def test_make_causal_mask(self): - x = jnp.ones((1, 3)) - y = layers.make_causal_mask(x) - self.assertEqual(y.shape, (1, 1, 3, 3)) - expected_y = jnp.array( - [[[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]]], jnp.float32 - ) - np.testing.assert_allclose(y, expected_y) - - def test_combine_masks(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), - None, - jnp.array([1, 1, 1, 1], jnp.float32), - jnp.array([1, 1, 1, 0], jnp.float32), - ] - y = layers.combine_masks(*masks) - np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) - - def test_combine_biases(self): - masks = [ - jnp.array([0, 1, 0, 1], jnp.float32), - None, - jnp.array([0, 1, 1, 1], jnp.float32), - jnp.array([0, 1, 1, 0], jnp.float32), - ] - y = layers.combine_biases(*masks) - np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) - - def test_make_decoder_mask_lm_unpacked(self): - decoder_target_tokens = jnp.array([6, 7, 3, 0]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32 - ) - expected_mask = jnp.array( - [[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]]] - ) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_lm_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask = jnp.array([[[ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask = jnp.array( - [[[ - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0], - ]]], - dtype=jnp.float32, - ) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_packed(self): - decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask = jnp.array([[[ - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ]]]) - np.testing.assert_array_equal(mask, expected_mask) - - def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask0 = jnp.array( - [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]] - ) - expected_mask1 = jnp.array( - [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] - ) - self.assertEqual(mask.shape, (2, 1, 4, 4)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - np.testing.assert_array_equal(mask[1, 0], expected_mask1) - - def test_make_decoder_mask_composite_causal_attention(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - ) - expected_mask0 = jnp.array([ - [1, 1, 0, 0, 1, 1, 0], - [1, 1, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ]) - - self.assertEqual(mask.shape, (1, 1, 7, 7)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - def test_make_decoder_mask_composite_causal_attention_packed(self): - decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) - decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) - decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) - mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=jnp.float32, - decoder_causal_attention=decoder_causal_attention, - decoder_segment_ids=decoder_segment_ids, - ) - expected_mask0 = jnp.array([ - [1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1], - ]) - - self.assertEqual(mask.shape, (1, 1, 9, 9)) - np.testing.assert_array_equal(mask[0, 0], expected_mask0) - - @parameterized.parameters({'f': 20}, {'f': 22}) - def test_multihead_dot_product_attention(self, f): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - np.random.seed(0) - inputs_q = np.random.randn(b, q, f) - inputs_kv = np.random.randn(b, k, f) - - # Projection: [b, q, f] -> [b, q, h, d] - # So the kernels have to be [f, h, d] - query_kernel = np.random.randn(f, h, d) - key_kernel = np.random.randn(f, h, d) - value_kernel = np.random.randn(f, h, d) - # `out` calculation: [b, q, h, d] -> [b, q, f] - # So kernel has to be [h, d, f] - out_kernel = np.random.randn(h, d, f) - - params = { - 'query': {'kernel': query_kernel.reshape(f, -1)}, - 'key': {'kernel': key_kernel.reshape(f, -1)}, - 'value': {'kernel': value_kernel.reshape(f, -1)}, - 'out': {'kernel': out_kernel.reshape(-1, f)}, - } - y = layers.MultiHeadDotProductAttention(**args).apply( - {'params': freeze(params)}, inputs_q, inputs_kv - ) - - query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) - key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) - value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = nn.softmax(logits, axis=-1) - combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) - y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) - np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) - - def test_multihead_dot_product_attention_caching(self): - # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim - b, h, d, k = 2, 3, 4, 5 - f = h * d - - base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) - args = base_args.init_args() - - cache = { - 'cached_key': np.zeros((b, h, d, k)), - 'cached_value': np.zeros((b, h, d, k)), - 'cache_index': np.array(0), - } - inputs_q = np.random.randn(b, 1, f) - inputs_kv = np.random.randn(b, 1, f) - - # Mock dense general such that q, k, v projections are replaced by simple - # reshaping. - def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument - return x.reshape(b, -1, h, d) - - with mock.patch.object( - layers.DenseGeneral, '__call__', new=mock_dense_general - ): - _, mutated = layers.MultiHeadDotProductAttention(**args).apply( - {'cache': freeze(cache)}, - inputs_q, - inputs_kv, - decode=True, - mutable=['cache'], - ) - updated_cache = mutated['cache'] - - # Perform the same mocked projection to generate the expected cache. - # (key|value): [b, 1, h, d] - key = mock_dense_general(None, inputs_kv) - value = mock_dense_general(None, inputs_kv) - - # cached_(key|value): [b, h, d, k] - cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] - cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] - cache['cache_index'] = np.array(1) - for name, array in cache.items(): - np.testing.assert_allclose(array, updated_cache[name]) - - def test_dot_product_attention(self): - # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim - b, q, h, d, k = 2, 3, 4, 5, 6 - np.random.seed(0) - query = np.random.randn(b, q, h, d) - key = np.random.randn(b, k, h, d) - value = np.random.randn(b, k, h, d) - bias = np.random.randn(b, h, q, k) - attn_out = layers.dot_product_attention(query, key, value, bias=bias) - logits = np.einsum('bqhd,bkhd->bhqk', query, key) - weights = jax.nn.softmax(logits + bias, axis=-1) - expected = np.einsum('bhqk,bkhd->bqhd', weights, value) - np.testing.assert_allclose(attn_out, expected, atol=1e-6) - - -class EmbeddingTest(parameterized.TestCase): - - def test_embedder_raises_exception_for_incorrect_input_type(self): - """Tests that inputs are integers and that an exception is raised if not.""" - embed = layers.Embed(num_embeddings=10, features=5) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - bad_inputs = inputs.astype(np.float32) - with self.assertRaisesRegex( - ValueError, 'Input type must be an integer or unsigned integer.' - ): - _ = embed.apply(variables, bad_inputs) - - @parameterized.named_parameters( - { - 'testcase_name': 'with_ones', - 'init_fn': jax.nn.initializers.ones, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 5 * 10, - }, - { - 'testcase_name': 'with_zeros', - 'init_fn': jax.nn.initializers.zeros, - 'num_embeddings': 10, - 'features': 5, - 'matrix_sum': 0, - }, - ) - def test_embedding_initializes_correctly( - self, init_fn, num_embeddings, features, matrix_sum - ): - """Tests if the Embed class initializes with the requested initializer.""" - embed = layers.Embed( - num_embeddings=num_embeddings, features=features, embedding_init=init_fn - ) - inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) - - def test_embedding_matrix_shape(self): - """Tests that the embedding matrix has the right shape.""" - num_embeddings = 10 - features = 5 - embed = layers.Embed(num_embeddings=num_embeddings, features=features) - inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) - variables = embed.init(jax.random.PRNGKey(0), inputs) - embedding_matrix = variables['params']['embedding'] - self.assertEqual((num_embeddings, features), embedding_matrix.shape) - - def test_embedding_attend(self): - """Tests that attending with ones returns sum of embedding vectors.""" - features = 5 - embed = layers.Embed(num_embeddings=10, features=features) - inputs = np.array([[1]], dtype=np.int64) - variables = embed.init(jax.random.PRNGKey(0), inputs) - query = np.ones(features, dtype=np.float32) - result = embed.apply(variables, query, method=embed.attend) - expected = np.sum(variables['params']['embedding'], -1) - np.testing.assert_array_almost_equal(result, expected) - - -class DenseTest(parameterized.TestCase): - - def test_dense_general_no_bias(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=4, - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - self.assertEqual(y.shape, (1, 4)) - np.testing.assert_allclose(y, np.full((1, 4), 3.0)) - - def test_dense_general_two_features(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - model = layers.DenseGeneral( - features=(2, 2), - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - # We transform the last input dimension to two output dimensions (2, 2). - np.testing.assert_allclose(y, np.full((1, 2, 2), 3.0)) - - def test_dense_general_two_axes(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 2, 2)) - model = layers.DenseGeneral( - features=3, - axis=(-2, 2), # Note: this is the same as (1, 2). - kernel_init=initializers.ones, - ) - y, _ = model.init_with_output(rng, x) - # We transform the last two input dimensions (2, 2) to one output dimension. - np.testing.assert_allclose(y, np.full((1, 3), 4.0)) - - def test_mlp_same_out_dim(self): - module = layers.MlpBlock( - intermediate_dim=4, - activations=('relu',), - kernel_init=nn.initializers.xavier_uniform(), - dtype=jnp.float32, - ) - inputs = np.array( - [ - # Batch 1. - [[1, 1], [1, 1], [1, 2]], - # Batch 2. - [[2, 2], [3, 1], [2, 2]], - ], - dtype=np.float32, - ) - params = module.init(random.PRNGKey(0), inputs, deterministic=True) - # self.assertEqual( - # jax.tree.map(lambda a: a.tolist(), params), - # { - # 'params': { - # 'wi': { - # 'kernel': [ - # [ - # -0.8675811290740967, - # 0.08417510986328125, - # 0.022586345672607422, - # -0.9124102592468262, - # ], - # [ - # -0.19464373588562012, - # 0.49809837341308594, - # 0.7808468341827393, - # 0.9267289638519287, - # ], - # ], - # }, - # 'wo': { - # 'kernel': [ - # [0.01154780387878418, 0.1397249698638916], - # [0.974980354309082, 0.5903260707855225], - # [-0.05997943878173828, 0.616570234298706], - # [0.2934272289276123, 0.8181164264678955], - # ], - # }, - # }, - # 'params_axes': { - # 'wi': { - # 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), - # }, - # 'wo': { - # 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), - # }, - # }, - # }, - # ) - result = module.apply(params, inputs, deterministic=True) # pylint: disable=unused-variable - # np.testing.assert_allclose( - # result.tolist(), - # [ - # [ - # [0.5237172245979309, 0.8508185744285583], - # [0.5237172245979309, 0.8508185744285583], - # [1.2344461679458618, 2.3844780921936035], - # ], - # [ - # [1.0474344491958618, 1.7016371488571167], - # [0.6809444427490234, 0.9663378596305847], - # [1.0474344491958618, 1.7016371488571167], - # ], - # ], - # rtol=1e-6, - # ) - - -class RelativePositionBiasesTest(absltest.TestCase): - - def setUp(self): - self.num_heads = 3 - self.query_len = 5 - self.key_len = 7 - self.relative_attention = layers.RelativePositionBiases( - num_buckets=12, - max_distance=10, - num_heads=3, - dtype=jnp.float32, - ) - super(RelativePositionBiasesTest, self).setUp() - - def test_relative_attention_bidirectional_params(self): - """Tests that bidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True - ) - param_shapes = jax.tree.map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, - { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': AxisMetadata( - names=('heads', 'relpos_buckets') - ), - }, - }, - ) - - def test_regression_relative_attention_bidirectional_values(self): - """Tests that bidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True - ) - self.assertEqual( - outputs.shape, (1, self.num_heads, self.query_len, self.key_len) - ) - # self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - # self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - # self.assertAlmostEqual(outputs[0, 1, 4, 6], 0.14510104, places=5) - # self.assertAlmostEqual(outputs[0, 2, 4, 6], -0.36783996, places=5) - - def test_relative_attention_unidirectional_params(self): - """Tests that unidirectional relative position biases have expected params.""" - params = self.relative_attention.init( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False - ) - param_shapes = jax.tree.map(lambda x: x.shape, params) - self.assertEqual( - param_shapes, - { - 'params': { - 'rel_embedding': (3, 12), - }, - 'params_axes': { - 'rel_embedding_axes': AxisMetadata( - names=('heads', 'relpos_buckets') - ), - }, - }, - ) - - def test_regression_relative_attention_unidirectional_values(self): - """Tests that unidirectional relative position biases match expected values. - - See top docstring note on matching T5X behavior for these regression tests. - """ - outputs, unused_params = self.relative_attention.init_with_output( - random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False - ) - self.assertEqual( - outputs.shape, (1, self.num_heads, self.query_len, self.key_len) - ) - # self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) - # self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) - # self.assertAlmostEqual(outputs[0, 1, 4, 6], -0.13101986, places=5) - # self.assertAlmostEqual(outputs[0, 2, 4, 6], 0.39296466, places=5) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/examples/t5/local_tiny.gin b/t5x-main/t5x/examples/t5/local_tiny.gin deleted file mode 100644 index 3e794f896292b2fd82ecc45367ad5e233ebd7252..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/local_tiny.gin +++ /dev/null @@ -1,68 +0,0 @@ -# A gin file to make the Transformer models tiny for faster local testing. -# -# When testing locally with CPU, there are a few things that we need. -# - tiny model size -# - small enough batch size -# - small sequence length -# - determinstic dataset pipeline -# -# This gin file adds such configs. To use this gin file, add it on top of the -# existing full-scale gin files. The ordering of the gin file matters. So this -# should be added after all the other files are added to override the same -# configurables. - -from __gin__ import dynamic_registration - -from t5x import partitioning -from t5x import trainer -from t5x import utils -from t5x.examples.t5 import network - -import __main__ as train_script - -train_script.train.random_seed = 42 # dropout seed -train/utils.DatasetConfig.seed = 42 # dataset seed - -TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 8} -LABEL_SMOOTHING = 0.0 -TRAIN_STEPS = 3 - -# Network specification overrides -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 - mlp_activations = ('gelu', 'linear') - dropout_rate = 0.0 - logits_via_embedding = False - -train/utils.DatasetConfig: - batch_size = 8 - shuffle = False - -train_eval/utils.DatasetConfig.batch_size = 8 - -train_script.train: - eval_period = 3 - eval_steps = 3 - -trainer.Trainer.num_microbatches = 0 -partitioning.PjitPartitioner: - num_partitions = 1 - model_parallel_submesh = None - -utils.CheckpointConfig: - restore = None - save = None - -infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS - - -# DISABLE INFERENCE EVAL -# train_script.train.infer_eval_dataset_cfg = None diff --git a/t5x-main/t5x/examples/t5/mt5/__init__.py b/t5x-main/t5x/examples/t5/mt5/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/mt5/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/t5/mt5/base.gin b/t5x-main/t5x/examples/t5/mt5/base.gin deleted file mode 100644 index 6dabd2c05d7a5256c7837c0bcd0a73581d01d2e9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/mt5/base.gin +++ /dev/null @@ -1,55 +0,0 @@ -# mT5 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.examples.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 250112 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 2048 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False diff --git a/t5x-main/t5x/examples/t5/mt5/large.gin b/t5x-main/t5x/examples/t5/mt5/large.gin deleted file mode 100644 index 5b0ea1cd9243f3e9b072267a4530f501e2c3c06f..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/mt5/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 Large model. - -include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 2816 diff --git a/t5x-main/t5x/examples/t5/mt5/small.gin b/t5x-main/t5x/examples/t5/mt5/small.gin deleted file mode 100644 index e3f8192cab2016b1cd41b8c87ca8a78cb8b4cb64..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/mt5/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 Small model. - -include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 6 - num_encoder_layers = 8 - num_decoder_layers = 8 - head_dim = 64 - mlp_dim = 1024 diff --git a/t5x-main/t5x/examples/t5/mt5/tiny.gin b/t5x-main/t5x/examples/t5/mt5/tiny.gin deleted file mode 100644 index ed83eecd0b229ffd8b50561241e268d9cfc3ecfb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/mt5/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/examples/t5/mt5/xl.gin b/t5x-main/t5x/examples/t5/mt5/xl.gin deleted file mode 100644 index 63178f5fc804346ac206797be6b6c5b0bf9a53c8..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/mt5/xl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 XL model. - -include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2048 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 5120 diff --git a/t5x-main/t5x/examples/t5/mt5/xxl.gin b/t5x-main/t5x/examples/t5/mt5/xxl.gin deleted file mode 100644 index e61a443d60cb92c6eb897f64cd5af1669a925129..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/mt5/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# mT5 XXL model. - -include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4096 - num_heads = 64 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 10240 diff --git a/t5x-main/t5x/examples/t5/network.py b/t5x-main/t5x/examples/t5/network.py deleted file mode 100644 index 475c71c2039d6a3d0429ac0e1641e8b2589d34b9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/network.py +++ /dev/null @@ -1,461 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5.1.1 Transformer model.""" - -from typing import Any, Sequence - -from flax import linen as nn -from flax import struct -import jax.numpy as jnp -from t5x.examples.t5 import layers - - -@struct.dataclass -class T5Config: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - - vocab_size: int - # Activation dtypes. - dtype: Any = jnp.float32 - emb_dim: int = 512 - num_heads: int = 8 - num_encoder_layers: int = 6 - num_decoder_layers: int = 6 - head_dim: int = 64 - mlp_dim: int = 2048 - # Activation functions are retrieved from Flax. - mlp_activations: Sequence[str] = ('relu',) - dropout_rate: float = 0.1 - # If `True`, the embedding weights are used in the decoder output layer. - logits_via_embedding: bool = False - # Whether to accumulate attention logits in float32 regardless of dtype. - float32_attention_logits: bool = False - - -class EncoderLayer(nn.Module): - """Transformer encoder layer.""" - - config: T5Config - relative_embedding: nn.Module - - @nn.compact - def __call__(self, inputs, encoder_mask=None, deterministic=False): - cfg = self.config - - # Relative position embedding as attention biases. - encoder_bias = self.relative_embedding( - inputs.shape[-2], inputs.shape[-2], True - ) - - # Attention block. - assert inputs.ndim == 3 - x = layers.LayerNorm(dtype=cfg.dtype, name='pre_attention_layer_norm')( - inputs - ) - # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - float32_logits=cfg.float32_attention_logits, - name='attention', - )(x, x, encoder_mask, encoder_bias, deterministic=deterministic) - x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) - x = x + inputs - - # MLP block. - y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) - # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = layers.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - name='mlp', - )(y, deterministic=deterministic) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - y = y + x - - return y - - -class DecoderLayer(nn.Module): - """Transformer decoder layer that attends to the encoder.""" - - config: T5Config - relative_embedding: nn.Module - - @nn.compact - def __call__( - self, - inputs, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - deterministic=False, - decode=False, - max_decode_length=None, - ): - cfg = self.config - - # Relative position embedding as attention biases. - l = max_decode_length if decode and max_decode_length else inputs.shape[-2] - decoder_bias = self.relative_embedding(l, l, False) - - # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - x = layers.LayerNorm(dtype=cfg.dtype, name='pre_self_attention_layer_norm')( - inputs - ) - - # Self-attention block - x = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - float32_logits=cfg.float32_attention_logits, - name='self_attention', - )( - x, - x, - decoder_mask, - decoder_bias, - deterministic=deterministic, - decode=decode, - ) - x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) - x = x + inputs - - # Encoder-Decoder block. - y = layers.LayerNorm( - dtype=cfg.dtype, name='pre_cross_attention_layer_norm' - )(x) - y = layers.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - head_dim=cfg.head_dim, - dropout_rate=cfg.dropout_rate, - float32_logits=cfg.float32_attention_logits, - name='encoder_decoder_attention', - )(y, encoded, encoder_decoder_mask, deterministic=deterministic) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - y = y + x - - # MLP block. - z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y) - z = layers.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - name='mlp', - )(z, deterministic=deterministic) - z = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - z, deterministic=deterministic - ) - z = z + y - - return z - - -class Encoder(nn.Module): - """A stack of encoder layers.""" - - config: T5Config - shared_embedding: nn.Module - - @nn.compact - def __call__( - self, encoder_input_tokens, encoder_mask=None, deterministic=False - ): - cfg = self.config - assert encoder_input_tokens.ndim == 2 # [batch, length] - rel_emb = layers.RelativePositionBiases( - num_buckets=32, - max_distance=128, - num_heads=cfg.num_heads, - dtype=cfg.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform' - ), - name='relpos_bias', - ) - - # [batch, length] -> [batch, length, emb_dim] - x = self.shared_embedding(encoder_input_tokens.astype('int32')) - x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic - ) - x = x.astype(cfg.dtype) - - for lyr in range(cfg.num_encoder_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = EncoderLayer( - config=cfg, relative_embedding=rel_emb, name=f'layers_{lyr}' - )(x, encoder_mask, deterministic) - - x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) - return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) - - -class Decoder(nn.Module): - """A stack of decoder layers as a part of an encoder-decoder architecture.""" - - config: T5Config - shared_embedding: nn.Module - - @nn.compact - def __call__( - self, - encoded, - decoder_input_tokens, - decoder_positions=None, - decoder_mask=None, - encoder_decoder_mask=None, - deterministic=False, - decode=False, - max_decode_length=None, - ): - cfg = self.config - assert decoder_input_tokens.ndim == 2 # [batch, len] - rel_emb = layers.RelativePositionBiases( - num_buckets=32, - max_distance=128, - num_heads=cfg.num_heads, - dtype=cfg.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform' - ), - name='relpos_bias', - ) - - # [batch, length] -> [batch, length, emb_dim] - y = self.shared_embedding(decoder_input_tokens.astype('int32')) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - y = y.astype(cfg.dtype) - - for lyr in range(cfg.num_decoder_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = DecoderLayer( - config=cfg, relative_embedding=rel_emb, name=f'layers_{lyr}' - )( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - deterministic=deterministic, - decode=decode, - max_decode_length=max_decode_length, - ) - - y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) - - # [batch, length, emb_dim] -> [batch, length, vocab_size] - if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for logit transform. - logits = self.shared_embedding.attend(y) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - else: - logits = layers.DenseGeneral( - cfg.vocab_size, - dtype=jnp.float32, # Use float32 for stabiliity. - kernel_axes=('embed', 'vocab'), - name='logits_dense', - )(y) - return logits - - -class Transformer(nn.Module): - """An encoder-decoder Transformer model.""" - - config: T5Config - - def setup(self): - cfg = self.config - self.shared_embedding = layers.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - attend_dtype=jnp.float32, # for logit training stability - embedding_init=nn.initializers.normal(stddev=1.0), - one_hot=True, - name='token_embedder', - ) - - self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) - self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - - def encode( - self, encoder_input_tokens, encoder_segment_ids=None, enable_dropout=True - ): - """Applies Transformer encoder-branch on the inputs.""" - cfg = self.config - assert encoder_input_tokens.ndim == 2, ( - 'Expected `encoder_input_tokens` to be of shape (batch, len). ' - f'Got {encoder_input_tokens.shape}' - ) - - # Make padding attention mask. - encoder_mask = layers.make_attention_mask( - encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype - ) - # Add segmentation block-diagonal attention mask if using segmented data. - if encoder_segment_ids is not None: - encoder_mask = layers.combine_masks( - encoder_mask, - layers.make_attention_mask( - encoder_segment_ids, - encoder_segment_ids, - jnp.equal, - dtype=cfg.dtype, - ), - ) - - return self.encoder( - encoder_input_tokens, encoder_mask, deterministic=not enable_dropout - ) - - def decode( - self, - encoded, - encoder_input_tokens, # only needed for masks - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - decoder_positions=None, - enable_dropout=True, - decode=False, - max_decode_length=None, - ): - """Applies Transformer decoder-branch on encoded-input and target.""" - cfg = self.config - - # Make padding attention masks. - if decode: - # Do not mask decoder attention based on targets padding at - # decoding/inference time. - decoder_mask = None - encoder_decoder_mask = layers.make_attention_mask( - jnp.ones_like(decoder_target_tokens), - encoder_input_tokens > 0, - dtype=cfg.dtype, - ) - else: - decoder_mask = layers.make_decoder_mask( - decoder_target_tokens=decoder_target_tokens, - dtype=cfg.dtype, - decoder_segment_ids=decoder_segment_ids, - ) - encoder_decoder_mask = layers.make_attention_mask( - decoder_target_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype - ) - - # Add segmentation block-diagonal attention masks if using segmented data. - if encoder_segment_ids is not None: - if decode: - raise ValueError( - 'During decoding, packing should not be used but ' - '`encoder_segment_ids` was passed to `Transformer.decode`.' - ) - - encoder_decoder_mask = layers.combine_masks( - encoder_decoder_mask, - layers.make_attention_mask( - decoder_segment_ids, - encoder_segment_ids, - jnp.equal, - dtype=cfg.dtype, - ), - ) - - logits = self.decoder( - encoded, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - deterministic=not enable_dropout, - decode=decode, - max_decode_length=max_decode_length, - ) - return logits - - def __call__( - self, - encoder_input_tokens, - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - encoder_positions=None, - decoder_positions=None, - *, - enable_dropout: bool = True, - decode: bool = False, - ): - """Applies Transformer model on the inputs. - - This method requires both decoder_target_tokens and decoder_input_tokens, - which is a shifted version of the former. For a packed dataset, it usually - has additional processing applied. For example, the first element of each - sequence has id 0 instead of the shifted EOS id from the previous sequence. - - Args: - encoder_input_tokens: input data to the encoder. - decoder_input_tokens: input token to the decoder. - decoder_target_tokens: target token to the decoder. - encoder_segment_ids: encoder segmentation info for packed examples. - decoder_segment_ids: decoder segmentation info for packed examples. - encoder_positions: encoder subsequence positions for packed examples. - decoder_positions: decoder subsequence positions for packed examples. - enable_dropout: Ensables dropout if set to True. - decode: Whether to prepare and use an autoregressive cache. - - Returns: - logits array from full transformer. - """ - encoded = self.encode( - encoder_input_tokens, - encoder_segment_ids=encoder_segment_ids, - enable_dropout=enable_dropout, - ) - - return self.decode( - encoded, - encoder_input_tokens, # only used for masks - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=encoder_segment_ids, - decoder_segment_ids=decoder_segment_ids, - decoder_positions=decoder_positions, - enable_dropout=enable_dropout, - decode=decode, - ) diff --git a/t5x-main/t5x/examples/t5/network_test.py b/t5x-main/t5x/examples/t5/network_test.py deleted file mode 100644 index dfcefd004f08cf362cb820f8f3a172491a43b451..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/network_test.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for network.""" - -import os - -from absl import flags -from absl.testing import absltest -from absl.testing import parameterized -import jax -import numpy as np -import seqio -from t5x import adafactor -from t5x import models -from t5x import test_utils -from t5x.examples.t5 import network - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -FLAGS = flags.FLAGS - - -def get_test_model( - emb_dim, - head_dim, - num_heads, - mlp_dim, - dtype='float32', - vocab_size=32128, - num_encoder_layers=2, - num_decoder_layers=2, -): - config = network.T5Config( - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - vocab_size=vocab_size, - dropout_rate=0, - emb_dim=emb_dim, - num_heads=num_heads, - head_dim=head_dim, - mlp_dim=mlp_dim, - dtype=dtype, - mlp_activations=('gelu', 'linear'), - ) - module = network.Transformer(config=config) - vocab = seqio.test_utils.sentencepiece_vocab() - optimizer_def = adafactor.Adafactor() - return models.EncoderDecoderModel( - module, vocab, vocab, optimizer_def=optimizer_def - ) - - -class NetworkTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - batch_size, max_decode_len, input_len = 2, 3, 4 - self.input_shapes = { - 'encoder_input_tokens': (batch_size, input_len), - 'decoder_input_tokens': (batch_size, max_decode_len), - } - np.random.seed(42) - self.batch = { - 'encoder_input_tokens': np.random.randint( - 3, 10, size=(batch_size, input_len) - ), - 'decoder_input_tokens': np.random.randint( - 3, 10, size=(batch_size, max_decode_len) - ), - 'decoder_target_tokens': np.random.randint( - 3, 10, size=(batch_size, max_decode_len) - ), - } - - def test_t5_1_1_regression(self): - np.random.seed(0) - batch_size, max_decode_len, input_len = 2, 3, 4 - batch = { - 'encoder_input_tokens': np.random.randint( - 3, 10, size=(batch_size, input_len) - ), - 'decoder_input_tokens': np.random.randint( - 3, 10, size=(batch_size, max_decode_len) - ), - 'decoder_target_tokens': np.random.randint( - 3, 10, size=(batch_size, max_decode_len) - ), - } - model = get_test_model( - emb_dim=13, - head_dim=64, - num_heads=8, - mlp_dim=2048, - vocab_size=10, - num_encoder_layers=3, - ) - params = model.get_initial_variables( - jax.random.PRNGKey(42), self.input_shapes - )['params'] - loss, _ = jax.jit(model.loss_fn)(params, batch, jax.random.PRNGKey(1)) # pylint: disable=unused-variable - # self.assertAlmostEqual(loss, 18.088945, delta=0.05) - - # predicted, scores = model.predict_batch_with_aux(params, batch) - # np.testing.assert_array_equal(predicted, [[7, 1, 0], [1, 0, 0]]) - # np.testing.assert_allclose( - # scores['scores'], [-3.040324, -1.928565], rtol=1e-2 - # ) - - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/examples/t5/t5_1_0/11B.gin b/t5x-main/t5x/examples/t5/t5_1_0/11B.gin deleted file mode 100644 index 003f659429befe3334ada2736a1f872f2fa440e7..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_0/11B.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.0 11B model. - -include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 128 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 128 - mlp_dim = 65536 diff --git a/t5x-main/t5x/examples/t5/t5_1_0/3B.gin b/t5x-main/t5x/examples/t5/t5_1_0/3B.gin deleted file mode 100644 index ccfcbd88e227fdf20143a80404523b0c15337417..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_0/3B.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.0 3B model. - -include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 128 - mlp_dim = 16384 diff --git a/t5x-main/t5x/examples/t5/t5_1_0/__init__.py b/t5x-main/t5x/examples/t5/t5_1_0/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_0/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/t5/t5_1_0/base.gin b/t5x-main/t5x/examples/t5/t5_1_0/base.gin deleted file mode 100644 index 5b7d1e34481004753ad21483df6106358ff67f06..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_0/base.gin +++ /dev/null @@ -1,55 +0,0 @@ -# T5.1.0 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.examples.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 3072 - mlp_activations = ('relu',) - dropout_rate = %DROPOUT_RATE - logits_via_embedding = True diff --git a/t5x-main/t5x/examples/t5/t5_1_0/large.gin b/t5x-main/t5x/examples/t5/t5_1_0/large.gin deleted file mode 100644 index 07d1b8eeb32f6948cda159c9a0233ef1fdaa5303..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_0/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.0 Large model. - -include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 4096 diff --git a/t5x-main/t5x/examples/t5/t5_1_0/small.gin b/t5x-main/t5x/examples/t5/t5_1_0/small.gin deleted file mode 100644 index 3c86b02a2dcfcfe18da1ee78abf763b3209dfdaf..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_0/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Small model. - -include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 8 - num_encoder_layers = 6 - num_decoder_layers = 6 - head_dim = 64 - mlp_dim = 2048 diff --git a/t5x-main/t5x/examples/t5/t5_1_0/tiny.gin b/t5x-main/t5x/examples/t5/t5_1_0/tiny.gin deleted file mode 100644 index ed83eecd0b229ffd8b50561241e268d9cfc3ecfb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_0/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/__init__.py b/t5x-main/t5x/examples/t5/t5_1_1/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/t5/t5_1_1/base.gin b/t5x-main/t5x/examples/t5/t5_1_1/base.gin deleted file mode 100644 index 0dbc43566b75815a991c5f2d9351cac257c56e66..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/base.gin +++ /dev/null @@ -1,55 +0,0 @@ -# T5.1.1 Base model. -from __gin__ import dynamic_registration - -import seqio -from t5x import adafactor -from t5x import models -from t5x.examples.t5 import network - -# ------------------- Loss HParam ---------------------------------------------- -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None -# Dropout should be specified in the "run" files -DROPOUT_RATE = %gin.REQUIRED - -# Vocabulary (shared by encoder and decoder) -VOCABULARY = @seqio.SentencePieceVocabulary() -seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" - -# ------------------- Optimizer ------------------------------------------------ -# `learning_rate` is set by `Trainer.learning_rate_fn`. -OPTIMIZER = @adafactor.Adafactor() -adafactor.Adafactor: - decay_rate = 0.8 - step_offset = 0 - logical_factor_rules = @adafactor.standard_logical_factor_rules() - -# ------------------- Model ---------------------------------------------------- -MODEL = @models.EncoderDecoderModel() -models.EncoderDecoderModel: - module = @network.Transformer() - input_vocabulary = %VOCABULARY - output_vocabulary = %VOCABULARY - optimizer_def = %OPTIMIZER - z_loss = %Z_LOSS - label_smoothing = %LABEL_SMOOTHING - loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR - -# ------------------- Network specification ------------------------------------ -network.Transformer.config = @network.T5Config() -network.T5Config: - vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency - dtype = 'bfloat16' - emb_dim = 768 - num_heads = 12 - num_encoder_layers = 12 - num_decoder_layers = 12 - head_dim = 64 - mlp_dim = 2048 - mlp_activations = ('gelu', 'linear') - dropout_rate = %DROPOUT_RATE - logits_via_embedding = False diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/__init__.py b/t5x-main/t5x/examples/t5/t5_1_1/examples/__init__.py deleted file mode 100644 index 0146a9b5441d4b6476a922553e6def9c1e3b598c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This empty file is needed for loading the gin files in this directory. diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin deleted file mode 100644 index 8f211f918750f84156202c1a66e00e4c40a476b6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin +++ /dev/null @@ -1,19 +0,0 @@ -# Register necessary SeqIO Tasks/Mixtures. -from __gin__ import dynamic_registration -import t5.data.mixtures -import __main__ as train_script - - -include 't5x/examples/t5/t5_1_1/base.gin' -include 't5x/configs/runs/pretrain.gin' - - -MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} -TRAIN_STEPS = 100000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 256 - - -train_script.train: - eval_period = 2000 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin deleted file mode 100644 index ca551ab2ed0348ce3024a2ef20f1464cd15d73ae..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin +++ /dev/null @@ -1,46 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as eval_script -import seqio -from t5.data import mixtures -from t5x import partitioning -from t5x import utils -from t5x import models - -include "t5x/examples/t5/t5_1_1/base.gin" # defines %MODEL. - -INITIAL_CHECKPOINT_PATH = %gin.REQUIRED -EVAL_OUTPUT_DIR = %gin.REQUIRED - -DROPOUT_RATE = 0.0 # unused boilerplate - - -eval_script.evaluate: - model = %MODEL # imported from separate gin file - dataset_cfg = @utils.DatasetConfig() - partitioner = @partitioning.PjitPartitioner() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - output_dir = %EVAL_OUTPUT_DIR - inference_evaluator_cls = @seqio.Evaluator - - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the dataset. - use_memory_cache = True - - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # Auto-computes the max feature lengths. - split = 'test' - batch_size = 32 - shuffle = False - seed = 42 - -partitioning.PjitPartitioner.num_partitions = 1 -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -utils.RestoreCheckpointConfig: - path = %INITIAL_CHECKPOINT_PATH - mode = 'specific' diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin deleted file mode 100644 index 588b264b57c8644723a39a0af70a215f96fac2bc..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin +++ /dev/null @@ -1,51 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -import t5.data.mixtures -from t5x import utils -from t5x import models - - -include 't5x/configs/runs/finetune.gin' -include 't5x/examples/t5/t5_1_1/base.gin' - -BATCH_SIZE = 128 -MIXTURE_OR_TASK_NAME = "wmt14_enfr_v003" -TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256} -DROPOUT_RATE = 0.1 -TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. - -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000" - -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -LOSS_NORMALIZING_FACTOR = 233472 - -train_script.train: - eval_period = 100 - -train_script.train: - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = 64 - shuffle = False - seed = 42 - use_cached = %USE_CACHED_TASKS - pack = False - module = %MIXTURE_OR_TASK_MODULE - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the dataset. - use_memory_cache = True diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin deleted file mode 100644 index d2337b668e76a7cddf9b88e10930dee75029be97..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin +++ /dev/null @@ -1,18 +0,0 @@ -from __gin__ import dynamic_registration -import t5.data.mixtures -import __main__ as train_script -from t5x import utils - -include 't5x/configs/runs/pretrain.gin' -include 't5x/examples/t5/t5_1_1/base.gin' - -TRAIN_STEPS = 100000 -BATCH_SIZE = 128 -MIXTURE_OR_TASK_NAME = "wmt14_enfr_v003" -TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256} -DROPOUT_RATE = 0.1 - -train_script.train: - eval_period = 2000 -utils.SaveCheckpointConfig: - period = 200 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin deleted file mode 100644 index ec1cbd8650284cb61bab64022249a0b245676507..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin +++ /dev/null @@ -1,63 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import adafactor -from t5x import models -from t5x import partitioning -from t5x import trainer -from t5x import utils -from t5x.examples.t5 import network - -include "t5x/examples/t5/t5_1_1/base.gin" -include "t5x/configs/runs/finetune.gin" - -MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" -MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} -TRAIN_STEPS = 5000 -LABEL_SMOOTHING = 0.1 -INITIAL_CHECKPOINT_PATH = None -# Note that `DROPOUT_RATE = 0.1` is specified in the finetune.gin but we just -# repeat to make it explicit. -DROPOUT_RATE = 0.1 - -train/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - use_custom_packing_ops = False - seed = 0 - -train_eval/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = False - use_custom_packing_ops = False - seed = 0 - -infer_eval/utils.DatasetConfig: - use_cached = False - -train_script.train: - eval_period = 250 - stats_period = 250 - eval_steps = 20 - random_seed = 0 - use_hardware_rng = True - -utils.CheckpointConfig.restore = None -utils.SaveCheckpointConfig: - period = 500 # checkpoint frequency - keep = 1 - -# Decoder overrides -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -trainer.Trainer.num_microbatches = 2 -utils.create_learning_rate_scheduler.warmup_steps = 1000 - -partitioning.PjitPartitioner: - model_parallel_submesh = (1, 1, 1, 2) - -adafactor.Adafactor: - logical_factor_rules = @adafactor.standard_logical_factor_rules() diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin deleted file mode 100644 index 9f8bf0ab7d4d6f0ca78ff47df8371d365602f5c0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin +++ /dev/null @@ -1,34 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as eval_script -from t5.data import mixtures -from t5x import partitioning -from t5x import utils - -include "t5x/examples/t5/t5_1_1/base.gin" # defines %MODEL. - -CHECKPOINT_PATH = %gin.REQUIRED # passed via commandline -EVAL_OUTPUT_DIR = %gin.REQUIRED # passed via commandline - -DROPOUT_RATE = 0.0 # unused boilerplate -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" - -eval_script.evaluate: - model = %MODEL # imported from separate gin file - dataset_cfg = @utils.DatasetConfig() - restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() - output_dir = %EVAL_OUTPUT_DIR - -utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # Auto-computes the max feature lengths. - split = 'test' - batch_size = 32 - shuffle = False - seed = 42 - -partitioning.PjitPartitioner.num_partitions = 2 - -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = 'specific' diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin deleted file mode 100644 index b17c7de4bc6c55d2b194ea163d298f7cd82f8f33..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin +++ /dev/null @@ -1,53 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -import t5.data.mixtures -from t5x import utils -from t5x import models - - -include 't5x/configs/runs/finetune.gin' -include 't5x/examples/t5/t5_1_1/base.gin' - -BATCH_SIZE = 128 -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256} -DROPOUT_RATE = 0.1 -TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. -USE_CACHED_TASKS = False - -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000" - -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -LOSS_NORMALIZING_FACTOR = 233472 - -train_script.train: - eval_period = 100 - -train_script.train: - train_dataset_cfg = @train/utils.DatasetConfig() - train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = %TASK_FEATURE_LENGTHS - split = 'validation' - batch_size = 64 - shuffle = False - seed = 42 - pack = False - use_cached = %USE_CACHED_TASKS - module = %MIXTURE_OR_TASK_MODULE - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the dataset. - use_memory_cache = True \ No newline at end of file diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin deleted file mode 100644 index a981a46092ec8e74db7869092003c3ba9dada953..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin +++ /dev/null @@ -1,63 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -import seqio -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils - -include "t5x/examples/t5/t5_1_1/base.gin" -include "t5x/configs/runs/pretrain.gin" - -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} -TRAIN_STEPS = 50000 -DROPOUT_RATE = 0.0 - -train/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - seed = 0 - -train_eval/utils.DatasetConfig: - batch_size = 128 - use_cached = False - pack = True - seed = 0 - -infer_eval/utils.DatasetConfig: - mixture_or_task_name = %MIXTURE_OR_TASK_NAME - task_feature_lengths = None # compute max - split = "validation" - seed = 0 - batch_size = 128 - shuffle = False - use_cached = False - -train_script.train: - eval_period = 500 - eval_steps = 20 - random_seed = 0 - use_hardware_rng = True - infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() - inference_evaluator_cls = @seqio.Evaluator - -seqio.Evaluator: - logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] - num_examples = None # Use all examples in the infer_eval dataset. - use_memory_cache = True - -utils.SaveCheckpointConfig: - period = 5000 # checkpoint frequency - -# `num_decodes` is equivalent to a beam size in a beam search decoding. -models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 - -partitioning.PjitPartitioner.num_partitions = 2 - -utils.create_learning_rate_scheduler: - factors = 'constant * rsqrt_decay' - base_learning_rate = 1.0 - warmup_steps = 10000 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin deleted file mode 100644 index 5c4ae4468349439a9fabde65c25255af66c94756..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin +++ /dev/null @@ -1,51 +0,0 @@ -# This gin file is to show how to switch to an optimizer other than -# Adafactor. Gin configuration makes it easy by simply importing any available -# optimizer in t5x/optimizers module. Note the optimizers in t5x/optimizers are -# wrapped version of optimizers implemented in optax. - -from __gin__ import dynamic_registration - -from t5x import optimizers -from t5x import utils -import optax - -include "t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" - -# In this case, we choose to switch to the AdamW optimizer with gradient clip. -OPTIMIZER = @optimizers.chain() - -optimizers.chain: - transformations = [@optax.clip(), @optax.adamw()] - -optax.clip: - max_delta = 1.0 - -optax.adamw: - # Unlike Adafactor, most optimizers require to specify - # `learning_rate`. `learning_rate` accepts a float number (e.g., 1e-4) or - # a schedule function, which should take an argument `step` and output - # a learning rate for that step. - # As for choices of schedule functions, we can either use T5x - # learning rate scheduler, i.e., utils.create_learning_rate_scheduler, or - # optax's native schedule functions, e.g., warmup_cosine_decay_schedule. - learning_rate = @optax.warmup_cosine_decay_schedule() - -optax.warmup_cosine_decay_schedule: - init_value = 0.0 - peak_value = 1e-4 - warmup_steps = 1000 - decay_steps = %TRAIN_STEPS - end_value = 0.0 - - -# Below is an example of using the T5X's schedule functions. -# Feel free to uncomment to try. -# optax.adamw: -# learning_rate = @utils.create_learning_rate_scheduler() - -# utils.create_learning_rate_scheduler: -# factors = 'constant * linear_warmup * rsqrt_decay' -# base_learning_rate = 0.01 -# warmup_steps = 10000 - - diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_lion.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_lion.gin deleted file mode 100644 index 1df9cd712f88ded29cc1070acb0f0b566ae34cfc..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_lion.gin +++ /dev/null @@ -1,51 +0,0 @@ -# This gin file is to show how to switch to an optimizer other than -# Adafactor. Gin configuration makes it easy by simply importing any available -# optimizer in t5x/optimizers module. Note the optimizers in t5x/optimizers are -# wrapped version of optimizers implemented in optax. - -from __gin__ import dynamic_registration - -from t5x import optimizers -from t5x import utils -import optax - -include "t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" - -# In this case, we choose to switch to the Lion optimizer with gradient clip. -OPTIMIZER = @optimizers.chain() - -optimizers.chain: - transformations = [@optax.clip(), @optax.lion()] - -optax.clip: - max_delta = 1.0 - -optax.lion: - # Unlike Adafactor, most optimizers require to specify - # `learning_rate`. `learning_rate` accepts a float number (e.g., 1e-4) or - # a schedule function, which should take an argument `step` and output - # a learning rate for that step. - # As for choices of schedule functions, we can either use T5x - # learning rate scheduler, i.e., utils.create_learning_rate_scheduler, or - # optax's native schedule functions, e.g., warmup_cosine_decay_schedule. - learning_rate = @optax.warmup_cosine_decay_schedule() - -optax.warmup_cosine_decay_schedule: - init_value = 0.0 - peak_value = 5e-5 - warmup_steps = 1000 - decay_steps = %TRAIN_STEPS - end_value = 0.0 - - -# Below is an example of using the T5X's schedule functions. -# Feel free to uncomment to try. -# optax.lion: -# learning_rate = @utils.create_learning_rate_scheduler() - -# utils.create_learning_rate_scheduler: -# factors = 'constant * linear_warmup * rsqrt_decay' -# base_learning_rate = 0.01 -# warmup_steps = 10000 - - diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin deleted file mode 100644 index 73898092b78ae06ba1fec42634b3fe8e3092a6d7..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin +++ /dev/null @@ -1,19 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as infer_script -from t5.data import mixtures -from t5x import partitioning -from t5x import utils - -include "t5x/examples/t5/t5_1_1/base.gin" -include "t5x/configs/runs/infer.gin" - -DROPOUT_RATE = 0.0 # unused but needs to be specified -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 64, "targets": 64} - -partitioning.PjitPartitioner.num_partitions = 1 - -utils.DatasetConfig: - split = "test" - batch_size = 32 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin deleted file mode 100644 index 4a4ccbeadbd5f118d6cfe2c0d0385a3f5d233d8e..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin +++ /dev/null @@ -1,11 +0,0 @@ -include 't5x/examples/t5/t5_1_1/small.gin' -include 't5x/configs/runs/pretrain.gin' - -# Register necessary SeqIO Tasks/Mixtures. -import t5.data.mixtures - -MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption" -TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} -TRAIN_STEPS = 10000 -DROPOUT_RATE = 0.0 -BATCH_SIZE = 256 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin deleted file mode 100644 index 90b45356f7ac75b71ecd69906053f008d3c6af3b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin +++ /dev/null @@ -1,22 +0,0 @@ -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5.data import mixtures -from t5x import models -from t5x import partitioning -from t5x import utils - -include "t5x/examples/t5/t5_1_1/small.gin" -include "t5x/configs/runs/finetune.gin" - -MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} -TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. -DROPOUT_RATE = 0.0 -INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" -# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained -# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be -# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: -# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. -LOSS_NORMALIZING_FACTOR = 233472 -USE_CACHED_TASKS = False diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/test_train_eval_t5_tiny.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/test_train_eval_t5_tiny.gin deleted file mode 100644 index 2dbb764afc749e1d868323eb4f3386ec1a7b84e4..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/test_train_eval_t5_tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# Test config to exercise train.py, very similar to test_train_t5_tiny.gin, -# except this only does evaluation, no training. - -from __gin__ import dynamic_registration - -import __main__ as train_script - -include 't5x/examples/t5/t5_1_1/examples/test_train_t5_tiny.gin' - -train_script.train: - run_eval_before_training = True - eval_period = 0 - total_steps = 0 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/examples/test_train_t5_tiny.gin b/t5x-main/t5x/examples/t5/t5_1_1/examples/test_train_t5_tiny.gin deleted file mode 100644 index 9006ad4dcf107a0a9ae12c7c7113bb772d532cfc..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/examples/test_train_t5_tiny.gin +++ /dev/null @@ -1,56 +0,0 @@ -# Test config to exercise train.py with model-based pjit partitioning. - -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import adafactor -from t5x import models -from t5x import partitioning -from t5x import trainer -from t5x import utils - -include 't5x/configs/runs/pretrain.gin' -include 't5x/examples/t5/t5_1_1/tiny.gin' - -MODEL_DIR = "/tmp" # Will be overridden in test. - -TRAIN_STEPS = 3 -MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" -MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" -TASK_FEATURE_LENGTHS = {"inputs": 32, "targets": 32} -DROPOUT_RATE = 0.0 - -models.EncoderDecoderModel: - z_loss = 0.0 - label_smoothing = 0.0 - loss_normalizing_factor = None - - -train/utils.DatasetConfig: - pack = False - seed = 0 - shuffle = False - use_cached = False - batch_size = 8 - -train_eval/utils.DatasetConfig: - pack = False - seed = 0 - shuffle = False - use_cached = False - batch_size = 8 - -train_script.train: - random_seed = 0 - eval_steps = 2 - actions={'TRAIN_EVAL': [@trainer.TerminateOnNanAction()]} - -trainer.TerminateOnNanAction: - task = %MIXTURE_OR_TASK_NAME - -partitioning.PjitPartitioner.num_partitions = 2 -utils.SaveCheckpointConfig.period = 4 - -# Overriding from pretrain.gin to keep magic constants in tests. -utils.create_learning_rate_scheduler: - warmup_steps = 1000 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/large.gin b/t5x-main/t5x/examples/t5/t5_1_1/large.gin deleted file mode 100644 index 6d92ef41984399ff6cc87b869e45b55fb1860a42..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/large.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Large model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 1024 - num_heads = 16 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 2816 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/small.gin b/t5x-main/t5x/examples/t5/t5_1_1/small.gin deleted file mode 100644 index 1c4f9d0dc6f89fc0bbd88d7116fedf508de8cc03..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/small.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 Small model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 512 - num_heads = 6 - num_encoder_layers = 8 - num_decoder_layers = 8 - head_dim = 64 - mlp_dim = 1024 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/tiny.gin b/t5x-main/t5x/examples/t5/t5_1_1/tiny.gin deleted file mode 100644 index ed83eecd0b229ffd8b50561241e268d9cfc3ecfb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/tiny.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 tiny model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 8 - num_heads = 4 - num_encoder_layers = 2 - num_decoder_layers = 2 - head_dim = 3 - mlp_dim = 16 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/xl.gin b/t5x-main/t5x/examples/t5/t5_1_1/xl.gin deleted file mode 100644 index 34f8cd6f312729454480a83822c1d8ff8920c242..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/xl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 XL model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 2048 - num_heads = 32 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 5120 diff --git a/t5x-main/t5x/examples/t5/t5_1_1/xxl.gin b/t5x-main/t5x/examples/t5/t5_1_1/xxl.gin deleted file mode 100644 index 1d4828687bfd79c78e47977bd2ff520efe3f9d1a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/examples/t5/t5_1_1/xxl.gin +++ /dev/null @@ -1,13 +0,0 @@ -# T5.1.1 XXL model. - -include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. - -# ------------------- Network specification overrides -------------------------- -network.Transformer.config = @network.T5Config() -network.T5Config: - emb_dim = 4096 - num_heads = 64 - num_encoder_layers = 24 - num_decoder_layers = 24 - head_dim = 64 - mlp_dim = 10240 diff --git a/t5x-main/t5x/export.py b/t5x-main/t5x/export.py deleted file mode 100644 index 2c832f7a0afad38df1ea86c44aab28a093c218e0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/export.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Exports a T5X model. - - -""" - -import os -from typing import Sequence -from absl import logging - -# Set Linen to add profiling information when constructing Modules. -# Must be set before flax imports. -# pylint:disable=g-import-not-at-top -os.environ.setdefault('FLAX_PROFILE', 'true') - -from t5x import export_lib - -if __name__ == '__main__': - # pylint:disable=g-import-not-at-top - from absl import app - from absl import flags - import fiddle as fdl - import gin - from t5x import config_utils - from t5x import gin_utils - # pylint:enable=g-import-not-at-top - - FLAGS = flags.FLAGS - - - flags.DEFINE_multi_string( - 'gin_file', - default=None, - help=( - 'Path to gin configuration file. Multiple paths may be passed and ' - 'will be imported in the given order, with later configurations ' - 'overriding earlier ones.' - ), - ) - - flags.DEFINE_multi_string( - 'gin_bindings', - default=[], - help='Individual gin bindings. Also used to integrate gin and XManager.', - ) - - flags.DEFINE_list( - 'gin_search_paths', - default=['t5x/configs'], - help=( - 'Comma-separated list of gin config path prefixes to be prepended ' - 'to suffixes given via `--gin_file`. If a file appears in. Only the ' - 'first prefix that produces a valid path for each suffix will be ' - 'used.' - ), - ) - - def main(argv: Sequence[str]): - """Wrapper for g3pdb post mortems.""" - _main(argv) - - def _main(argv: Sequence[str]): - """True main function.""" - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - - if config_utils.using_fdl(): - config = config_utils.config_with_fiddle(export_lib.save) - export_with_fiddle = fdl.build(config) - export_with_fiddle() - else: - save_with_gin = gin.configurable(export_lib.save) - - gin_utils.parse_gin_flags( - FLAGS.gin_search_paths, FLAGS.gin_file, FLAGS.gin_bindings - ) - logging.info('Creating inference function...') - save_with_gin() - - config_utils.run(main) diff --git a/t5x-main/t5x/export_lib.py b/t5x-main/t5x/export_lib.py deleted file mode 100644 index 5acc0ed10cd05eed39e41ae5bfdb20645c67cccf..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/export_lib.py +++ /dev/null @@ -1,1666 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functions for exporting a T5X model.""" - -import dataclasses -import functools -import inspect -import itertools -import json -import os -import os.path -import random -import string -import typing -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union - -from absl import logging -from flax.core import frozen_dict -import flax.traverse_util -import jax -from jax.experimental import jax2tf # type: ignore[import] -import jax.numpy as jnp -import ml_collections -import numpy as np -import seqio -from t5x import checkpoints -from t5x import decoding -from t5x import models -from t5x import partitioning -from t5x import utils -import tensorflow as tf # type: ignore -import typing_extensions - -from tensorflow_serving.apis import predict_pb2 -from tensorflow_serving.apis import prediction_log_pb2 - - -PyTree = Any -ConfigDict = ml_collections.ConfigDict -DecoderParamsSpec = Sequence[Tuple[str, tf.DType, Sequence[int]]] -PreprocessorFn = Callable[..., Mapping[str, tf.Tensor]] -WarmupExample = Union[Union[str, bytes], List[int]] -WarmupExamples = List[WarmupExample] -PostprocessorFn = Callable[ - [Tuple[Any, Any]], Union[Tuple[Any, Any], Mapping[str, Any]] -] -InferenceFn = Callable[[Mapping[str, Any], Any], PyTree] - - -class CreatePreprocessorFnNew(typing_extensions.Protocol): - - def __call__( - self, - batch_size: Optional[int], - output_features: Mapping[str, seqio.Feature], - task_feature_lengths: Mapping[str, int], - tokenized_inputs: bool, - ) -> Tuple[PreprocessorFn, Sequence[tf.TensorSpec]]: - ... - - -# Old signature, for backwards-compatibility. -# TODO(marcrasi): Delete this after migrating clients. -CreatePreprocessorFnOld = Callable[ - [Mapping[str, seqio.Feature], Mapping[str, int], bool], PreprocessorFn -] -CreatePreprocessorFn = Union[CreatePreprocessorFnOld, CreatePreprocessorFnNew] - - -@dataclasses.dataclass -class CustomInferenceMode: - # The name of the model function which can be fetched from - # getattr(model, model_fn_name). - model_fn_name: str - # Fetch useful output from the raw output of the model function. - fetch_output: Optional[Callable[[PyTree], PyTree]] = None - # Constant keyword aguments to append when calling the model function. - model_fn_const_kwargs: None | Mapping[str, Any] = None - - -class CreatePostprocessorFn(typing_extensions.Protocol): - - def __call__( - self, - vocab: seqio.Vocabulary, - inference_mode: Union[str, CustomInferenceMode], - decode_outputs: bool = True, - output_feature_names: Optional[List[str]] = None, - ) -> PostprocessorFn: - ... - - -class CreateDecodingStateCallbackFn(typing_extensions.Protocol): - - def __call__( - self, - vocab: seqio.Vocabulary, - num_decodes: int = 1, - output_feature_names: Optional[List[str]] = None, - call_tf_graph: bool = False, - ) -> decoding.StateCallbackFn: - ... - - -class ExportableModule(tf.Module): - """Wrapper for TF function + parameters to be exported.""" - - def __init__( - self, - preproc_tf_fn, - model_tf_fn, - postproc_tf_fn, - params: Mapping[str, Any], - batch_size: Optional[int], - num_batch_threads: int = 8, - max_enqueued_batches: int = 64, - batch_timeout_micros: int = 1000_000, - max_batch_size: Optional[int] = None, - allowed_batch_sizes: Optional[Sequence[int]] = None, - jit_compile: bool = True, - use_batch_function: bool = False, - use_gpu: bool = False, - enable_large_batch_splitting: bool = True, - ): - super().__init__() - - def flat_params(params): - flat_param_vars = {} - for k, v in flax.traverse_util.flatten_dict(params).items(): - flat_param_vars[k] = tf.Variable( - np.asarray(v), trainable=False, name='__'.join(k) - ) - return flat_param_vars - - if use_gpu: - tf_device = tf.config.list_logical_devices('GPU')[0] - with tf.device(tf_device): - flat_param_vars = flat_params(params) - else: - flat_param_vars = flat_params(params) - self._variables = list(flat_param_vars.values()) - param_vars = frozen_dict.freeze( - flax.traverse_util.unflatten_dict(flat_param_vars) - ) - self._preproc_tf_fn = preproc_tf_fn - self._postproc_tf_fn = postproc_tf_fn - - # TF trackable resources must be assigned to an attribute of the module. - # TODO(dinghua): We should have a more formal API for getting the - # trackable members from pre/post-processing functions. - self._other_trackables = [] - for fn in (self._preproc_tf_fn, self._postproc_tf_fn): - if hasattr(fn, 'trackable_resources'): - self._other_trackables.append(fn.trackable_resources) - - # Note: jit_compile=True also instructs the TPU inference converter v2 to - # wrap this function with `TPUPartitionedCall`. - self._model_tf_fn = tf.function( - lambda x: model_tf_fn(param_vars, x), - autograph=False, - jit_compile=jit_compile, - ) - self._batch_size = batch_size - self._num_batch_threads = num_batch_threads - self._max_enqueued_batches = max_enqueued_batches - self._batch_timeout_micros = batch_timeout_micros - self._allowed_batch_sizes = allowed_batch_sizes - self._use_batch_function = use_batch_function - self._max_batch_size = max_batch_size - self._enable_large_batch_splitting = enable_large_batch_splitting - - @functools.partial(tf.function, autograph=False, jit_compile=False) - def __call__(self, *input_batches) -> Tuple[Any, Any]: - if not self._use_batch_function: - return self._call(*input_batches) - - if self._allowed_batch_sizes: - if self._batch_size is not None: - raise ValueError('allowed_batch_sizes requires polymorphic batch size') - max_batch_size = self._max_batch_size or max(self._allowed_batch_sizes) - allowed_batch_sizes = self._allowed_batch_sizes - elif self._batch_size is not None: - max_batch_size = self._max_batch_size or self._batch_size - allowed_batch_sizes = [self._batch_size] - else: - raise ValueError( - 'Need to set either batch_size or allowed_batch_sizes when ' - 'using batch_function.' - ) - batch_wrapper = tf.nondifferentiable_batch_function( - num_batch_threads=self._num_batch_threads, - max_enqueued_batches=self._max_enqueued_batches, - max_batch_size=max_batch_size, - batch_timeout_micros=self._batch_timeout_micros, - allowed_batch_sizes=allowed_batch_sizes, - enable_large_batch_splitting=self._enable_large_batch_splitting, - ) - flattended, tree_def = jax.tree_util.tree_flatten(input_batches) - return batch_wrapper(functools.partial(self._call, tree_def=tree_def))( - *flattended - ) - - def _call(self, *args, tree_def=None): - if tree_def is not None: - input_batches = jax.tree_util.tree_unflatten(tree_def, args) - else: - input_batches = args - features = self._preproc_tf_fn(*input_batches) - model_output = self._model_tf_fn(features) - return self._postproc_tf_fn(model_output) - - @property - def tpu_func(self): - return self._model_tf_fn - - @property - def export_batch_sizes(self): - return self._allowed_batch_sizes or [self._batch_size] - - @tf.function(autograph=False, jit_compile=False) - def preproc_func(self, *args): - return self._preproc_tf_fn(*args) - - -def get_train_state_initializer( - model: models.BaseTransformerModel, - partitioner: partitioning.BasePartitioner, - task_feature_lengths: Mapping[str, int], - batch_size: Optional[int], - trailing_shapes: Optional[Mapping[str, Tuple[int, ...]]] = None, -) -> Optional[utils.TrainStateInitializer]: - """Creates an TrainStateInitializer based on the model and partitioning.""" - if not partitioner: - return None - - data_layout = partitioner.get_data_layout(batch_size) - p_batch_size = data_layout.batch_size - feature_converter = model.FEATURE_CONVERTER_CLS(pack=False) - model_feature_lengths = feature_converter.get_model_feature_lengths( - task_feature_lengths - ) - input_shapes = {} - for k, l in model_feature_lengths.items(): - input_shapes[k] = (p_batch_size, l) - if feature_converter.MODEL_FEATURES[k].rank > 1: - if trailing_shapes is None or k not in trailing_shapes: - raise ValueError( - 'Must set the trailing shape--`...?` in ' - '`(batch_size, seqlen, ...?)`--for higher rank ' - f'feature {k}' - ) - input_shapes[k] += trailing_shapes[k] - train_state_initializer = utils.TrainStateInitializer( - optimizer_def=None, - init_fn=model.get_initial_variables, - input_shapes=input_shapes, - partitioner=partitioner, - ) - utils.log_model_info( - None, train_state_initializer.global_train_state_shape, partitioner - ) - return train_state_initializer - - -def flatten( - compute_outputs: PyTree, assert_output_len=None -) -> Tuple[jnp.ndarray, ...]: - values, _ = jax.tree_util.tree_flatten(compute_outputs) - if assert_output_len is not None: - assert len(values) == assert_output_len - return tuple(values) - - -_BUILTIN_INFERENCE_MODE_PARAMS = { - 'predict': { - 'model_fn_name': 'predict_batch_with_aux', - 'default_output_len': 2, - }, - 'score': {'model_fn_name': 'score_batch', 'default_output_len': 1}, -} - - -def create_inference_function( - *, - model: models.BaseTransformerModel, - inference_mode: Union[str, CustomInferenceMode], - partitioner: Optional[partitioning.BasePartitioner], - train_state_initializer: Optional[utils.TrainStateInitializer], - decoding_state_callback_fn: Optional[decoding.StateCallbackFn] = None, - enable_jax2tf: bool, - enable_xla: bool = True, - polymorphic_shapes_inputs: Optional[Any] = None, - native_lowering: bool = True, - native_lowering_platforms: Optional[Sequence[str]] = None, - model_fn_extra_kwargs: Optional[Mapping[str, Any]] = None, - jax2tf_disable_platform_checks: bool = False, - output_len: Optional[int] = None, -) -> Callable[[Mapping[str, Any], Any], PyTree]: - """Fetches a model and returns the inference function based on inference_mode.""" - # Always use native serialization. The non-native serialization is deprecated. - del native_lowering - if partitioner and train_state_initializer: - maybe_partition = lambda fn: partitioner.partition( # pylint:disable=g-long-lambda - fn, - # TODO(b/121310741): Re-enable pytype. - # pytype:disable=wrong-arg-types - in_axis_resources=( - train_state_initializer.train_state_axes.params, - partitioning.PartitionSpec( - 'data', - ), - ), - out_axis_resources=partitioning.PartitionSpec( - 'data', - ), - # pytype:enable=wrong-arg-types - ) - - else: - maybe_partition = lambda fn: fn - - if not isinstance(inference_mode, CustomInferenceMode): - if inference_mode not in _BUILTIN_INFERENCE_MODE_PARAMS: - raise ValueError( - '`inference_mode` must be a string in ' - f'{list(_BUILTIN_INFERENCE_MODE_PARAMS.keys())} or a ' - f'`CustomInferenceMode`. Got inference_mode={inference_mode}.' - ) - default_mode_params = _BUILTIN_INFERENCE_MODE_PARAMS[inference_mode] - assert_output_len = ( - default_mode_params['default_output_len'] - if output_len is None - else output_len - ) - inference_mode = CustomInferenceMode( - default_mode_params['model_fn_name'], - fetch_output=functools.partial( - flatten, assert_output_len=assert_output_len - ), - ) - - inference_mode = typing.cast(CustomInferenceMode, inference_mode) - - if inference_mode.model_fn_name == 'predict_batch_with_aux': - # Extract `decoder_params` passed by the preprocessor. Decoder params are - # supported only for `predict_batch_with_aux`. - # - # TODO(b/256173604): Make the following Gin-configurable. - - def model_fn( - params: Mapping[str, Any], inputs: Mapping[str, jnp.ndarray] - ) -> Tuple[Any, Any]: - batch = dict(inputs) - - decoder_params = batch.pop('decoder_params', {}) - if decoding_state_callback_fn is not None: - decoder_params['state_callback_fn'] = decoding_state_callback_fn - - kwargs = {} - if decoder_params: - kwargs['decoder_params'] = decoder_params - if model_fn_extra_kwargs: - kwargs.update(model_fn_extra_kwargs) - # pytype: disable=wrong-keyword-args - return model.predict_batch_with_aux(params, batch, **kwargs) - # pytype: enable=wrong-keyword-args - - else: - model_fn = getattr(model, inference_mode.model_fn_name) - - if inference_mode.model_fn_const_kwargs: - model_fn = functools.partial( - model_fn, **inference_mode.model_fn_const_kwargs - ) - - model_fn = maybe_partition(model_fn) - if enable_jax2tf: - disabled_checks = ( - [jax2tf.DisabledSafetyCheck.platform()] - if jax2tf_disable_platform_checks - else [] - ) - if not native_lowering_platforms: - # Change default value to make the exported cpu model still work. - native_lowering_platforms = ['cpu', 'tpu'] - model_fn = jax2tf.convert( - model_fn, - polymorphic_shapes=[None, polymorphic_shapes_inputs], - native_serialization_platforms=native_lowering_platforms, - native_serialization_disabled_checks=disabled_checks, - enable_xla=enable_xla, - ) - - def inference_fn( - params: Mapping[str, Any], batch: Mapping[str, jnp.ndarray] - ) -> PyTree: - outputs = model_fn(params, batch) - if inference_mode.fetch_output: - outputs = inference_mode.fetch_output(outputs) - return outputs - - return inference_fn - - -def load_params_from_checkpoint( - restore_checkpoint_cfg: utils.RestoreCheckpointConfig, - train_state_initializer: Optional[utils.TrainStateInitializer], - partitioner: partitioning.BasePartitioner, -) -> frozen_dict.FrozenDict: - """Loads the checkpoint and casts the variable.""" - if train_state_initializer is not None: - restore_cfg, ckpt_paths = utils.get_first_valid_restore_config_and_paths( - [restore_checkpoint_cfg] - ) - if len(ckpt_paths) != 1: - raise ValueError( - f'Expected only 1 checkpoint but got {len(ckpt_paths)} for ' - f'config(s): {restore_cfg}' - ) - train_state, _ = utils.create_checkpoint_manager_and_restore( - train_state_initializer=train_state_initializer, - partitioner=partitioner, - restore_checkpoint_cfg=restore_cfg, - restore_path=ckpt_paths[0], - fallback_init_rng=jax.random.PRNGKey(0), - save_checkpoint_cfg=None, - model_dir=None, - ds_iter=None, - use_orbax=True, - ) - return train_state.params # pytype:disable=attribute-error - else: - if restore_checkpoint_cfg.mode != 'specific': - raise NotImplementedError("Only mode='specific' is currently supported") - if not isinstance(restore_checkpoint_cfg.path, str): - raise NotImplementedError('Only string paths are currently supported') - variables = checkpoints.load_t5x_checkpoint( - path=restore_checkpoint_cfg.path, - state_transformation_fns=( - restore_checkpoint_cfg.state_transformation_fns - ), - restore_dtype=jnp.dtype(restore_checkpoint_cfg.dtype), - ) - return frozen_dict.freeze(variables['target']) - - -def create_single_tensor_input_signature( - batch_size: Optional[int], - task_feature_lengths: Mapping[str, int], - tokenized_inputs: bool = False, - name='text_batch', -) -> Sequence[tf.TensorSpec]: - """Returns an input signature for a model that takes a single input tensor. - - Args: - batch_size: Batch size for model to process. If None, then batch - polymorphism is invoked. - task_feature_lengths: Mapping from 'inputs' and 'targets' to sequence - lengths. - tokenized_inputs: specifies whether the input is expected to be - pre-tokenized. If so, the preprocessor expects an int32 tensor of shape - [B, N] rather than a string tensor of shape [B]. - name: the name of the single `tf.TensorSpec` in the input signature. - """ - if tokenized_inputs: - inputs_length = task_feature_lengths['inputs'] - return (tf.TensorSpec([batch_size, inputs_length], tf.int32, name=name),) - else: - return (tf.TensorSpec([batch_size], tf.string, name=name),) - - -def bucketize_tokenized_input( - input_tensor: tf.Tensor, bucket_keys: List[int] -) -> tf.Tensor: - """Returns a scalar tf.Tensor indicates the bucketized length of the input_tensor. - - Args: - input_tensor: A tokenized tf.Tensor. - bucket_keys: a bucket of sequence lengths, sorted ascendingly. - """ - index = tf.searchsorted(bucket_keys, [tf.shape(input_tensor)[0]]) - index = tf.math.minimum(index, len(bucket_keys) - 1) - index = tf.squeeze(index) - return tf.constant(bucket_keys)[index] - - -def truncate_and_pad_tokenized_input( - input_tensor: tf.Tensor, - max_length: int, - allowed_lengths: Optional[list[int]] = None, -) -> tf.Tensor: - """Truncate the examples to the maximum feature length, and pad to bucketed sequence lengths.""" - if allowed_lengths: - allowed_lengths.sort() - if max_length != allowed_lengths[-1]: - raise ValueError( - 'Expected the largest allowed length to be the same as the task' - f' feature length {max_length}, but got {allowed_lengths[-1]}.' - ) - max_length = bucketize_tokenized_input(input_tensor, allowed_lengths) - input_tensor = input_tensor[:max_length] - input_tensor = tf.pad( - input_tensor, [[0, max_length - tf.shape(input_tensor)[0]]] - ) - if not allowed_lengths: - input_tensor.set_shape([max_length]) - return input_tensor - - -def _build_ragged_feature_length_spec( - feature_key: str, - bucket_keys: Optional[Mapping[str, list[int]]], - task_feature_lengths: Mapping[str, int], -) -> tf.RaggedTensorSpec: - """Build a spec of the feature tensor's sequence length after preprocessing. - - The sequence length will be dynamic when polymorphic sequence length buckets - are configured for the given feature. - Args: - feature_key: The name of an input feature tensor. - bucket_keys: A mapping from feature key to allowed sequence lengths. - task_feature_lengths: A mapping from feature key to fixed sequence length. - - Returns: - A RaggedTensorSpec representing the sequence length dimension of the - feature tensor. - """ - if bucket_keys and feature_key in bucket_keys: - # Polymorphic sequence length. - feature_length = None - else: - feature_length = task_feature_lengths[feature_key] - return tf.RaggedTensorSpec( - shape=[feature_length], - dtype=tf.int32, - ragged_rank=0, - ) - - -# TODO(danielandor): More principled score-mode input format. -def create_preprocessor( - batch_size: Optional[int], - output_features: Mapping[str, seqio.Feature], - task_feature_lengths: Mapping[str, int], - tokenized_inputs: bool = False, - *, - input_tensor_name: str = 'text_batch', - split_separator: Optional[str] = None, - bucket_keys: Optional[Mapping[str, List[int]]] = None, -) -> Tuple[PreprocessorFn, Sequence[tf.TensorSpec]]: - """Builds a function based on the config task to tokenize and batch the input text. - - Args: - batch_size: Batch size for model to process. If None, then batch - polymorphism is invoked. - output_features: Mapping from 'inputs' and 'targets' to seqio.Feature. - task_feature_lengths: Mapping from 'inputs' and 'targets' to sequence - lengths. - tokenized_inputs: specifies whether the input is expected to be - pre-tokenized. If so, the preprocessor expects an int32 tensor of shape - [B, N] rather than a string tensor of shape [B]. - input_tensor_name: the name of the input tensor. - split_separator: If given, splits the input text at the first separator, and - sets the target text for scoring to the second element. If None, the - target is set to the empty string. The latter is appropriate for predict - mode. - bucket_keys: If given, bucketizes the tensors according to bucket_keys. - - Returns: - The preprocessor function. - """ - - def preprocess(input_texts: tf.Tensor) -> Mapping[str, tf.Tensor]: - """TF-based preprocessor that takes a batch of text and converts it to model features.""" - if tokenized_inputs: - inputs = input_texts # actually an int32 tensor of shape [B, N]. - targets = tf.broadcast_to( - tf.constant(0, dtype=tf.int32), tf.shape(input_texts) - ) - elif split_separator is None: - inputs = input_texts - targets = tf.broadcast_to(tf.constant(''), tf.shape(input_texts)) - else: - ragged_split = tf.strings.split( - input_texts, sep=split_separator, maxsplit=1 - ) - split = ragged_split.to_tensor(shape=[tf.shape(input_texts)[0], 2]) - inputs, targets = split[:, 0], split[:, 1] - - # TODO(b/188656799): Generalize this code to work with arbitrary models. - def featurize(text, k): - """Replicates what tokenization + seqio.EncDecFeatureConverter does, without Dataset.""" - vocab = output_features[k].vocabulary # type: seqio.Vocabulary - length = task_feature_lengths[k] - if not tokenized_inputs: # if inputs are tokenized, we don't re-tokenize. - t = vocab.encode_tf(text) - else: - t = text - if output_features[k].add_eos: - # The following matches the default behavior of the prediction server, - # which uses seqio.preprocessors.append_eos_after_trim, implemented at: - # https://github.com/google/seqio/tree/main/seqio/preprocessors.py;l=250;rcl=480228505 - t = tf.concat([t[: length - 1], [vocab.eos_id]], axis=0) - allowed_lengths = bucket_keys.get(k) if bucket_keys else None - t = truncate_and_pad_tokenized_input( - t, max_length=length, allowed_lengths=allowed_lengths - ) - ar_inputs = seqio.feature_converters.autoregressive_inputs(t) - loss_weights = seqio.feature_converters.non_padding_position(t) - return t, ar_inputs, loss_weights - - encoder_output_signature = _build_ragged_feature_length_spec( - 'inputs', bucket_keys, task_feature_lengths - ) - encoder_input_tokens, _, _ = tf.map_fn( - functools.partial(featurize, k='inputs'), - inputs, - fn_output_signature=(encoder_output_signature,) * 3, - ) - encoder_input_tokens = encoder_input_tokens.to_tensor() - - decoder_output_signature = _build_ragged_feature_length_spec( - 'targets', bucket_keys, task_feature_lengths - ) - decoder_target_tokens, decoder_input_tokens, loss_weights = tf.map_fn( - functools.partial(featurize, k='targets'), - targets, - fn_output_signature=(decoder_output_signature,) * 3, - ) - decoder_target_tokens = decoder_target_tokens.to_tensor() - decoder_input_tokens = decoder_input_tokens.to_tensor() - loss_weights = loss_weights.to_tensor() - - return dict( - encoder_input_tokens=encoder_input_tokens, - decoder_target_tokens=decoder_target_tokens, - decoder_input_tokens=decoder_input_tokens, - decoder_loss_weights=loss_weights, - ) - - input_signature = create_single_tensor_input_signature( - batch_size, task_feature_lengths, tokenized_inputs, input_tensor_name - ) - return preprocess, input_signature - - -def create_dual_encoder_preprocessor( - batch_size: Optional[int], - output_features: Mapping[str, seqio.Feature], - task_feature_lengths: Mapping[str, int], - tokenized_inputs: bool = False, - input_tensor_name: str = 'text_batch', - bucket_keys: Optional[Mapping[str, List[int]]] = None, - split_separator: Optional[str] = None, -) -> Tuple[PreprocessorFn, Sequence[tf.TensorSpec]]: - """Builds a function based on the config task to tokenize and batch the input text.""" - - def preprocess(input_texts: tf.Tensor) -> Mapping[str, tf.Tensor]: - """TF-based preprocessor that takes a batch of text and converts it to model features.""" - if tokenized_inputs: - inputs = input_texts - targets = tf.broadcast_to( - tf.constant(0, dtype=tf.int32), tf.shape(input_texts) - ) - elif split_separator is None: - inputs = input_texts - targets = tf.broadcast_to(tf.constant(''), tf.shape(input_texts)) - else: - ragged_split = tf.strings.split( - input_texts, sep=split_separator, maxsplit=1 - ) - split = ragged_split.to_tensor(shape=[tf.shape(input_texts)[0], 2]) - inputs, targets = split[:, 0], split[:, 1] - - # TODO(b/188656799): Generalize this code to work with arbitrary models. - def featurize(text, k): - """Replicates what tokenization + nlp.nlx.t5x_retrieval.DualEncoderFeatureConverter does, without Dataset.""" - vocab = output_features[k].vocabulary # type: seqio.Vocabulary - length = task_feature_lengths[k] - if not tokenized_inputs: # if inputs are tokenized, we don't re-tokenize. - t = vocab.encode_tf(text) - else: - t = text - if output_features[k].add_eos: - t = tf.concat([t[: length - 1], [vocab.eos_id]], axis=0) - allowed_lengths = bucket_keys.get(k) if bucket_keys else None - t = truncate_and_pad_tokenized_input( - t, max_length=length, allowed_lengths=allowed_lengths - ) - return t - - left_encoder_input_tokens = tf.map_fn( - functools.partial(featurize, k='inputs'), - inputs, - fn_output_signature=_build_ragged_feature_length_spec( - 'inputs', bucket_keys, task_feature_lengths - ), - ).to_tensor() - - right_encoder_input_tokens = tf.map_fn( - functools.partial(featurize, k='targets'), - targets, - fn_output_signature=_build_ragged_feature_length_spec( - 'targets', bucket_keys, task_feature_lengths - ), - ).to_tensor() - - return dict( - left_encoder_input_tokens=left_encoder_input_tokens, - right_encoder_input_tokens=right_encoder_input_tokens, - ) - - input_signature = create_single_tensor_input_signature( - batch_size, task_feature_lengths, tokenized_inputs, input_tensor_name - ) - return preprocess, input_signature - - -def create_decoder_preprocessor( - output_features: Mapping[str, seqio.Feature], - task_feature_lengths: Mapping[str, int], - tokenized_inputs: bool = False, - input_feature: str = 'inputs', -) -> PreprocessorFn: - """Returns a function to tokenize and featurize inputs for decoder only models. - - Args: - output_features: Mapping from 'inputs' and 'targets' to seqio.Feature. - task_feature_lengths: Mapping from 'inputs' and 'targets' to sequence - lengths. - tokenized_inputs: specifies whether the input is expected to be - pre-tokenized. If so, the preprocessor expects an int32 tensor padded with - 0s to shape of [B, N] rather than a string tensor of shape [B]. - input_feature: Name of the feature provided by `input_texts`, e.g., 'inputs' - or 'targets'. - """ - - def preprocess(input_texts: tf.Tensor) -> Mapping[str, tf.Tensor]: - """TF-based preprocessor that takes a batch of text and converts it to model features.""" - - def tokenize(text): - feature = output_features[input_feature] - vocab = feature.vocabulary # type: seqio.Vocabulary - if not tokenized_inputs: # if inputs are tokenized, we don't re-tokenize. - t = vocab.encode_tf(text) - else: - t = text - if feature.add_eos: - t = tf.concat([t, [vocab.eos_id]], axis=-1) - return t - - decoder_tokens = tf.map_fn( - tokenize, - input_texts, - fn_output_signature=(tf.int32), - ) - - if input_feature == 'inputs': - # 'inputs_width' is the length of 'inputs' (excluding padding 0). - ragged_input_tokens = tf.RaggedTensor.from_tensor( - decoder_tokens, padding=0 - ) - inputs_length = tf.cast(ragged_input_tokens.row_lengths(), dtype=tf.int32) - inputs_width = tf.expand_dims(inputs_length, -1) - inputs_width_add_pos = inputs_width + 1 - else: - inputs_width = tf.zeros(tf.shape(decoder_tokens)[0], dtype=tf.int32) - inputs_width_add_pos = inputs_width - - def featurize(text, length): - text = text[:length] - text = tf.pad(text, [[0, length - tf.shape(text)[0]]]) - text.set_shape([length]) - ar_inputs = seqio.feature_converters.autoregressive_inputs(text) - loss_weights = seqio.feature_converters.non_padding_position(text) - - return text, ar_inputs, loss_weights - - targets_length = sum(task_feature_lengths.values()) - decoder_target_tokens, decoder_input_tokens, decoder_loss_weights = ( - tf.map_fn( - functools.partial(featurize, length=targets_length), - decoder_tokens, - fn_output_signature=(tf.int32, tf.int32, tf.int32), - ) - ) - positions = tf.range(tf.shape(decoder_target_tokens)[-1]) - positions = tf.repeat( - [positions], tf.shape(decoder_target_tokens)[0], axis=0 - ) - - decoder_causal_attention = tf.cast( - positions < inputs_width_add_pos, dtype=decoder_target_tokens.dtype - ) - - inputs = positions < inputs_width - padding_mask = tf.cast(decoder_loss_weights, dtype=tf.bool) - - decoder_loss_weights = tf.cast( - tf.math.logical_xor(inputs, padding_mask), - dtype=decoder_target_tokens.dtype, - ) - - return dict( - decoder_input_tokens=decoder_input_tokens, - decoder_target_tokens=decoder_target_tokens, - decoder_loss_weights=decoder_loss_weights, - decoder_causal_attention=decoder_causal_attention, - ) - - return preprocess - - -def _default_value_for_spec(v): - return tf.zeros(v.shape, v.dtype).numpy() - - -def _feature_description_from_element_spec(element_spec): - """Feature description from element spec.""" - feature_description = {} - for k, v in element_spec.items(): - if isinstance(v, tf.SparseTensorSpec): - feature_description[k] = tf.io.VarLenFeature(dtype=v.dtype) - elif isinstance(v, tf.TensorSpec): - if v.shape.is_fully_defined(): - feature_description[k] = tf.io.FixedLenFeature( - shape=v.shape, - dtype=v.dtype, - default_value=_default_value_for_spec(v), - ) - else: - if v.shape[0] is None and v.shape[1:].is_fully_defined(): - # We only parse single examples (not batches) so the - # FixeLenSequenceFeature will never need to add padding through - # `default_value`. - feature_description[k] = tf.io.FixedLenSequenceFeature( - shape=v.shape[1:], dtype=v.dtype, allow_missing=True - ) - else: - raise ValueError( - 'Except for the first dimension, all dimentions of shape for ' - f'feature {k} need to be known but received {v.shape!s}.' - ) - else: - raise ValueError( - f'Cannot generate feature description for feature "{k}" with ' - f'element spec type {type(v)}; ' - 'supported types: tf.SparseTensorSpec, tf.TensorSpec.' - ) - return feature_description - - -class PreprocessorFnFromTask(object): - """A PreprocessorFn based on seqio.Task.""" - - def __init__( - self, - batch_size: Optional[int], - model: models.BaseTransformerModel, - task_feature_lengths: Mapping[str, int], - task_name: str = '', - serialized_examples: bool = True, - run_precache: bool = False, - ): - self.task = seqio.TaskRegistry.get(task_name) - if serialized_examples: - ds = self.task.source.get_dataset(self.task.splits[0]) - feature_description = _feature_description_from_element_spec( - ds.element_spec - ) - self.parse_example = functools.partial( - tf.io.parse_single_example, features=feature_description - ) - else: - self.parse_example = lambda x: x - - self.feature_converter = model.FEATURE_CONVERTER_CLS(pack=False) - self.task_feature_lengths = task_feature_lengths - self.batch_size = batch_size - self.run_precache = run_precache - - def is_trackable_resource(x): - return isinstance(x, tf.saved_model.experimental.TrackableResource) - - self.trackable_resources = list() - for p in self.task.preprocessors: - # TODO(dinghua): We should have a more formal API for getting the - # trackable members from a seqio preprocessor. - for _, tr in inspect.getmembers(p, is_trackable_resource): - self.trackable_resources.append(tr) - - def process_fn(self, examples: tf.Tensor) -> Mapping[str, tf.Tensor]: - """Converts serialized tf.Examples to batched model features. - - Args: - examples: batch examples. If `self.batch_size` is not None, - `examples.shape[0]` must be the same as `self.batch_size`. - - Returns: - A Mapping from feature names to batch features. - """ - ds = tf.data.Dataset.from_tensor_slices(examples) - # Dataset of parsed tf Examples. - ds = ds.map(self.parse_example) - if self.run_precache: - ds = self.task.preprocess_precache(ds) - ds = self.task.preprocess_postcache(ds, self.task_feature_lengths) - # Dataset of batched model features. - ds = self.feature_converter( - ds, task_feature_lengths=self.task_feature_lengths - ) - # Assume all batch size are the same. - single_feature = jax.tree_util.tree_leaves(examples)[0] - if self.batch_size is not None: - single_feature.shape[:1].assert_is_compatible_with([self.batch_size]) - ds = ds.batch(self.batch_size, drop_remainder=True) - else: - batch_size = tf.cast(tf.shape(single_feature)[0], dtype=tf.int64) - ds = ds.batch(batch_size, drop_remainder=True) - # As we process one batch at a time, the dataset ds has a single batch. - return ds.get_single_element() - - def __call__(self, examples: tf.Tensor) -> Mapping[str, tf.Tensor]: - return self.process_fn(examples) - - -def create_preprocessor_from_task( - batch_size: Optional[int], - output_features: Mapping[str, seqio.Feature], # unused - task_feature_lengths: Mapping[str, int], - tokenized_inputs: bool, - *, - model: models.BaseTransformerModel, - task_name: str = '', - serialized_examples: bool = True, - run_precache: bool = False, - input_tensor_name: str = 'text_batch', -) -> Tuple[PreprocessorFn, Sequence[tf.TensorSpec]]: - """Create a preprocessor based on a seqio task.""" - del output_features - return PreprocessorFnFromTask( - batch_size, - model, - task_feature_lengths, - task_name, - serialized_examples, - run_precache, - ), create_single_tensor_input_signature( - batch_size, task_feature_lengths, tokenized_inputs, input_tensor_name - ) - - -def create_preprocessor_with_decoder_params( - batch_size: Optional[int], - output_features: Mapping[str, seqio.Feature], # unused - task_feature_lengths: Mapping[str, int], - tokenized_inputs: bool, - *, - create_preprocessor_fn: CreatePreprocessorFn, - decoder_params_spec: DecoderParamsSpec, -) -> Tuple[PreprocessorFn, Sequence[tf.TensorSpec]]: - """Creates a preprocessor and adds decoder params as inputs. - - Args: - batch_size: See `save`. - output_features: See `save`. - task_feature_lengths: See `save`. - tokenized_inputs: See `save`. - create_preprocessor_fn: A function that creates a preprocessor to be - wrapped. - decoder_params_spec: A sequence of `(name, dtype, per_example_shape)` for - decoder params to be exposed as inputs. The decoder must be able to accept - the listed decoder params on a per-example basis, i.e., the shape of each - decoder param will be [batch_size, *per_example_shape]. Decoder params are - appended to the inputs in the specified order. - - Returns: - A preprocessor that calls `create_preprocessor_fn(...)` with additional - inputs representing decoder params and adds the specified `decoder_params` - as a new feature. - """ - - # TODO(marcrasi): Delete after migrating clients. - if 'batch_size' in inspect.signature(create_preprocessor_fn).parameters: - # New signature. - preprocessor, input_signature = create_preprocessor_fn( - batch_size, output_features, task_feature_lengths, tokenized_inputs - ) # type: ignore - else: - # Old signature. - preprocessor = create_preprocessor_fn( - output_features, task_feature_lengths, tokenized_inputs - ) # type: ignore - input_signature = create_single_tensor_input_signature( - batch_size, task_feature_lengths, tokenized_inputs - ) - - def wrapped(*args: tf.Tensor) -> Mapping[str, tf.Tensor]: - # Splice the args into inputs and decoder params. - num_decoder_params = len(decoder_params_spec) - decoder_params_values = args[-num_decoder_params:] - inputs = args[:-num_decoder_params] - - features = dict(preprocessor(*inputs)) - - # Add decoder params as additional features. They are removed from the - # features dict in `create_inference_function`. - decoder_params = {} - for (name, _, _), value in zip(decoder_params_spec, decoder_params_values): - decoder_params[name] = value - features['decoder_params'] = decoder_params - - return features - - input_signature = tuple(input_signature) + tuple( - tf.TensorSpec((batch_size,) + tuple(per_example_shape), dtype, name=name) - for name, dtype, per_example_shape in decoder_params_spec - ) - return wrapped, input_signature - - -def _maybe_name_outputs( - feature_values: Tuple[Any, ...], feature_names: Optional[List[str]] -) -> Union[Tuple[Any, ...], Mapping[str, Any]]: - """Names the output features if feature_names are specified.""" - if feature_names is None: - # Even in single arg case, the returned sequence is going to make sure that - # we have consistent behaviors. - return feature_values - if len(feature_values) != len(feature_names): - raise ValueError( - f'Output feature names {feature_names} must match ' - f'number of outputs {len(feature_values)}' - ) - return dict(zip(feature_names, feature_values)) - - -def create_postprocessor( - vocab: seqio.Vocabulary, - inference_mode: Union[str, CustomInferenceMode], - decode_outputs: bool = True, - output_feature_names: Optional[List[str]] = None, - add_token_length_to_output: Optional[bool] = False, -) -> PostprocessorFn: - """Creates a TF postprocessor function. - - Args: - vocab: The vocab to use to decode. - inference_mode: 'predict', 'score' or a CustomInferenceMode instance. - decode_outputs: whether to decode output tokens. - output_feature_names: A list of names to name the output for the savedmodel. - e.g., ['output_a', 'output_b'] will tag the savedmodel output to obtain - two entries with 'output_a' and 'output_b'. The order must match the - outputs from the module. - add_token_length_to_output: A boolean to indicate whether to include token - length infomation in the output. If True, will append 'num_tokens' and - 'num_tokens_before_eos' to the output. 'num_tokens' indicates the length - of the decoder's output. 'num_tokens_before_eos' will count the tokens - until the first 'vocab.eos_id' (excluding the 'eos_id'), or the same as - 'num_tokens' if 'vocab.eos_id' is None. - - Returns: - A function that that post processing on inference outputs. - """ - if inference_mode == 'predict': - - def postprocessor( - values: Tuple[Any, Any], - ) -> Union[Tuple[Any, Any], Mapping[str, Any]]: - tokens, scores = values - if decode_outputs: - decoded = vocab.decode_tf(tokens) - if add_token_length_to_output: - num_tokens = tf.reduce_sum( - tf.ones_like(tokens, dtype=tf.int32), axis=-1 - ) - if vocab.eos_id: - after_eos = tf.cumsum( - tf.cast(tf.equal(tokens, vocab.eos_id), tf.int32), - axis=-1, - ) - before_eos = tf.cast( - tf.logical_not(tf.cast(after_eos, tf.bool)), tf.int32 - ) - num_tokens_before_eos = tf.reduce_sum(before_eos, axis=-1) - else: - num_tokens_before_eos = num_tokens - if isinstance(decoded, tf.RaggedTensor): - decoded = decoded.to_tensor() - return _maybe_name_outputs( - feature_values=( - decoded, - scores, - num_tokens, - num_tokens_before_eos, - ), - feature_names=output_feature_names, - ) - # If add_eos=False, vocab.decode_tf returns a tf.Tensor rather than - # a tf.RaggedTensor. - if isinstance(decoded, tf.RaggedTensor): - decoded = decoded.to_tensor() - return _maybe_name_outputs( - feature_values=(decoded, scores), feature_names=output_feature_names - ) - else: - return _maybe_name_outputs( - feature_values=(tokens, scores), feature_names=output_feature_names - ) - - return postprocessor - else: - return functools.partial( - _maybe_name_outputs, feature_names=output_feature_names - ) - - - - -def _request_for_batch( - text_batch: WarmupExamples, - model_name: str, - input_tensor_name: str, - signature_name: str, - batch_size: Optional[int], - decoder_params_spec: Optional[DecoderParamsSpec] = None, - input_tensor_dtype: Optional[tf.DType] = None, -) -> predict_pb2.PredictRequest: - """Adds a single batch of Predict warmup data.""" - request = predict_pb2.PredictRequest() - request.model_spec.name = model_name - request.model_spec.signature_name = signature_name - if input_tensor_dtype is not None: - dtype = input_tensor_dtype - elif text_batch and isinstance(text_batch[0], (str, bytes)): - dtype = tf.string - else: - dtype = tf.int32 - # Truncate/Pad the request to have batch_size. - adjusted_batch = text_batch - if batch_size is not None: - adjusted_batch = list( - itertools.islice(itertools.cycle(text_batch), batch_size) - ) - request.inputs[input_tensor_name].CopyFrom( - tf.make_tensor_proto(adjusted_batch, dtype=dtype) - ) - if decoder_params_spec is not None: - for name, dtype, per_example_shape in decoder_params_spec: - request.inputs[name].CopyFrom( - tf.make_tensor_proto( - tf.zeros((len(adjusted_batch),) + tuple(per_example_shape), dtype) - ) - ) - return request - - -def _request_to_prediction_log( - request: predict_pb2.PredictRequest, -) -> prediction_log_pb2.PredictionLog: - """Creates a PredictionLog for the Predict method.""" - return prediction_log_pb2.PredictionLog( - predict_log=prediction_log_pb2.PredictLog(request=request) - ) - - -def generate_examples_with_sequence_lengths( - sequence_lengths: list[int], - single_token_example: WarmupExample | None = None, - vocabulary: seqio.Vocabulary | None = None, - character_set: str | Sequence[str] = string.ascii_letters + string.digits, - leave_room_for_eos: bool = False, - chars_per_token_upper_bound: int = 10, -) -> list[WarmupExamples]: - """Creates synthetic sequences that have the requested number of tokens. - - The examples will be computed by one of the following methods: - - If `single_token_example` is set: repeat a single token that is known to - be always `N` tokens long when repeated `N` times. - - If `vocabulary` is set: generate a random string that is `N` tokens long, as - measured by `vocabulary`. - - Args: - sequence_lengths: The sequence lengths to generate examples for. - single_token_example: An example such that `N*ex` is always `N` tokens long. - This is used to build sequences of a specified size. Defaults to `'Q'`, - which satisfies this property for the tokenizer used by pretrained English - T5X models. **NOTE**: This is brittle to variations in the tokenizer, so - prefer using `vocabulary` instead. - vocabulary: The seqio.Vocabulary used by the model. - character_set: The set of characters to use when generating random strings. - Defaults to letters and digits. - leave_room_for_eos: Whether the model will add EOS after the example. If - true, the generated examples will be one token shorter than the requested - length. - chars_per_token_upper_bound: The upper bound for the amount of characters - contained in a single token for this vocabulary. This determines how large - of a random string to start with when trying to generate a specific number - of tokens. - - Returns: - A list of WarmupExamples batches with lengths in tokens equal to - `sequence_lengths`. - """ - # Warmup examples should be deterministic. - random.seed(0) - if leave_room_for_eos: - sequence_lengths = [l - 1 for l in sequence_lengths] - if single_token_example and vocabulary: - raise ValueError( - 'Only one of `single_token_example` and `vocabulary` can be set.' - ) - elif vocabulary: - # Generate a random string that is exactly `N` tokens long, as measured by - # the provided tokenizer. - # TODO: b/331419045 - Add support for models with pretokenized inputs - # (dtype=tf.int32). - def _generate_example(num_tokens: int) -> WarmupExamples: - if num_tokens == 0: - return [''] - random_string = ''.join( - random.choice(character_set) - for _ in range(num_tokens * chars_per_token_upper_bound) - ) - all_ids = vocabulary.encode(random_string) - if len(all_ids) < num_tokens: - # Even if chars_per_token_upper_bound is set high enough, this can - # happen with unknown tokens. See b/294826076#comment4 (encoding Chinese - # characters with an English tokenizer). - raise ValueError( - 'Generated a random warmup example that is shorter than' - f' {num_tokens} tokens. Make sure the characters in character_set' - ' are valid in the vocabulary, or increase' - ' chars_per_token_upper_bound.' - ) - for start_index in range(len(all_ids) - num_tokens): - # Truncating may return an empty string for some IDs, for example - # vocabulary.decode([3]), so search for an ID subarray whose - # resulting string actually has the correct length. - truncated_ids = all_ids[start_index : start_index + num_tokens] - example = vocabulary.decode(truncated_ids) - generated_num_tokens = len(vocabulary.encode(example)) - if generated_num_tokens == num_tokens: - return [example] - raise ValueError( - f'Could not generate a valid string with {num_tokens} tokens. This' - ' may happen if the characters in character_set do not represent the' - ' vocabulary, or if chars_per_token_upper_bound is too small.' - ) - - return [_generate_example(length) for length in sequence_lengths] - else: - single_token_example = single_token_example or 'Q' - logging.warning( - 'Using single_token_example to generate warmup examples is brittle to' - ' variations in the tokenizer. Prefer explicitly passing a vocabulary' - ' instead.' - ) - return [[single_token_example * l] for l in sequence_lengths] - - -def write_warmup_examples( - text_batch: WarmupExamples, - output_dir: str, - model_name: str, - signature_name: str, - *, - generate_examples_fn: Optional[Callable[[], list[WarmupExamples]]] = None, - batch_sizes: List[Optional[int]], - input_tensor_name: str = 'text_batch', - decoder_params_spec: Optional[DecoderParamsSpec] = None, - request_to_prediction_log: Callable[ - [predict_pb2.PredictRequest], prediction_log_pb2.PredictionLog - ] = _request_to_prediction_log, - input_tensor_dtype: Optional[tf.DType] = None, -): - """Writes warmup examples for all batch_sizes requested. - - The text_batch is either filled to batch_size or truncated based on the - different batch_sizes. - For example, if text_batch has length 2 while requested batch_size is 4, it is - repeated two times. If text_batch has length 2 while requested batch_size is - 1, it is truncated to length 1. - - Args: - text_batch: A batch of texts used as warmup examples. - output_dir: The directory for writing the warmup examples to. - model_name: The name of the savedmodel spec. - signature_name: Optional name of the exported function. - generate_examples_fn: An optional no-arg function that generates a list of - synthetic warmup batches. If provided, `text_batch` will be ignored, and - all synthetic batches will be written out, for all batch sizes. NOTE: This - differs from `text_batch`, which only supplies a single batch of examples - that are written together, instead of many batches that are written - separately. - batch_sizes: A list of batch sizes to warmup with. The written number of - tfrecords will be equal to the size of batch_sizes. The list might contain - None entries, and the warmup examples for the None entry won't be padded - or truncated. - input_tensor_name: The entry name of the PredictRequest inputs dict. - decoder_params_spec: The parameter specifciations on decoding. If present, - dummy data (0s) with specified shape/dtype will be written into warmup - examples. - request_to_prediction_log: A function that creates a PredictionLog from a - given request. - input_tensor_dtype: The dtype of the input tensor. - """ - if generate_examples_fn: - if text_batch: - logging.warning( - 'Ignoring provided warmup batch. Using `generate_examples_fn` to' - ' generate warmup examples instead.' - ) - warmup_examples = generate_examples_fn() - else: - warmup_examples = [text_batch] - - assets_extra = os.path.join(output_dir, 'assets.extra') - tf.io.gfile.makedirs(assets_extra) - warmup_output = os.path.join(assets_extra, 'tf_serving_warmup_requests') - with tf.io.TFRecordWriter(warmup_output) as writer: - for warmup_example in warmup_examples: - for batch_size in batch_sizes: - logging.info('Writing warmup data for batch size: %s ...', batch_size) - request = _request_for_batch( - warmup_example, - model_name, - input_tensor_name, - signature_name, - batch_size, - decoder_params_spec, - input_tensor_dtype, - ) - log = request_to_prediction_log(request) - writer.write(log.SerializeToString()) - - - - -def _standardize_output_features( - mixture_or_task_name: Optional[str], - output_features: Optional[Mapping[str, seqio.Feature]], -): - """Standarize the output_features from user inputs.""" - new_output_features = output_features - if mixture_or_task_name is not None and output_features is not None: - raise ValueError( - 'Only one of mixture_or_task_name and output_features may be non empty.' - ) - if mixture_or_task_name is not None: - logging.info('Fetching output features from task %s', mixture_or_task_name) - new_output_features = seqio.get_mixture_or_task( - mixture_or_task_name - ).output_features - return new_output_features - - -def _standardize_output_dirs(output_dir: Union[str, Mapping[str, str]]): - """Standardize the format of output_dirs from user input.""" - logging.info('Standardizing the output_dir: %s', output_dir) - if output_dir is None: - raise ValueError('output_dir is mandatory') - if isinstance(output_dir, str): - output_dirs = {'tpu': output_dir} - else: - output_dirs = dict(output_dir) - if 'cpu' not in output_dirs: - if 'tpu' not in output_dirs: - raise ValueError('output_dir["cpu"] or output_dir["tpu"] is mandatory') - export_version = os.path.basename(output_dirs['tpu']) - if not export_version.isdigit(): - raise ValueError( - 'output_dir must be in the form ${BASE}/${VERSION}, ' - 'where ${VERSION} is an integer. Got a non-numeric ' - f'version {export_version}.' - ) - output_dirs['cpu'] = os.path.join( - os.path.dirname(output_dirs['tpu']) + '_cpu', export_version - ) - logging.info('Result standardized output_dirs: %s', output_dirs) - return output_dirs - - -def create_fake_input(signature: Dict[str, tf.TensorSpec]) -> Any: - """Create all zeros fake inputs according to signature spec. - - Args: - signature: A dictionary of tensor specs that will used for serving. - - Returns: - A pytree with the same structure as signature with all zeros tf.Tensor. - """ - - def _gen_dummy_tensor(ts: tf.TensorSpec): - shape = ts.shape.as_list() - if not all(shape[1:]): - raise ValueError( - 'Only supports polymorphic batch size at leading dimension, got ' - f'{ts} in the input signature.' - ) - if shape and shape[0] is None: - shape[0] = 1 - return tf.zeros(shape, ts.dtype) - - return jax.tree_util.tree_map(_gen_dummy_tensor, signature) - - -def create_batch_polymorphic_shapes( - input_signature, - preprocessor, - *, - create_fake_input_fn=create_fake_input, - overrides=None, -): - """Creates batch polymorhic shapes for jax2tf, and override specific shapes.""" - is_poly_batch = not all( - s.shape.is_fully_defined() - for s in jax.tree_util.tree_leaves(input_signature) - ) - - # All shapes are static. - if not is_poly_batch and not overrides: - return None - - fake_inputs = create_fake_input_fn(input_signature) - features = preprocessor(*fake_inputs) - - # All the features have a leading batch dimension. - shapes = jax.tree_util.tree_map( - lambda _: 'b, ...' if is_poly_batch else None, features - ) - if overrides: - shapes.update(overrides) - return shapes - - -def save( - *, - model: models.BaseTransformerModel, - inference_mode: str | CustomInferenceMode, - restore_checkpoint_cfg: utils.RestoreCheckpointConfig, - exportable_module_cls: Type[ExportableModule], - create_preprocessor_fn: CreatePreprocessorFn = create_preprocessor, - create_inference_function_fn: Callable[ - ..., InferenceFn - ] = create_inference_function, - create_postprocessor_fn: CreatePostprocessorFn = create_postprocessor, - partitioner: Optional[partitioning.BasePartitioner], - create_decoding_state_callback_fn: Optional[ - CreateDecodingStateCallbackFn - ] = None, - output_features: Optional[Mapping[str, seqio.Feature]], - task_feature_lengths: Mapping[str, int], - batch_size: Optional[int], - output_dir: Union[str, Mapping[str, str]], - model_name: str, - warmup_examples: Optional[WarmupExamples] = None, - tokenized_inputs: bool = False, - write_warmup_example_fn: Callable[..., None] = write_warmup_examples, - mixture_or_task_name: Optional[str] = None, - validation_examples: Optional[List[Any]] = None, - native_lowering: bool = True, - native_lowering_platforms: Optional[Sequence[str]] = None, - enable_xla: bool = True, - decode_outputs: Optional[bool] = None, - trailing_shapes: Optional[Mapping[str, Tuple[int, ...]]] = None, - output_vocab_feature_name: Optional[str] = 'targets', - signature_name: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - create_polymorphic_shapes_fn: Any = create_batch_polymorphic_shapes, -): - """Saves the passed EncoderDecoderModel as a TPU-enabled TF SavedModel. - - Args: - model: - inference_mode: "predict", "score" or a CustomInferenceMode instance. - restore_checkpoint_cfg: Configuration for restoring model from checkpoint. - exportable_module_cls: A configured implementation of ExportableModule. - create_preprocessor_fn: Configurable func. to create the PreprocessorFn. - create_inference_function_fn: Configurable func. to create the InferenceFn. - create_postprocessor_fn: Configurable func. to create the PostprocessorFn. - partitioner: Partitioner, usually for Pjit. - create_decoding_state_callback_fn: Configurable func. to create an optional - decoding.StateCallbackFn. - output_features: Output Features of the task. - task_feature_lengths: Input and target lengths. - batch_size: Batch size for model to process. If None, then batch - polymorphism is invoked. - output_dir: This is either: (a) A path in ${BASE}/${VERSION} format output - the final TPU-converted saved model. The CPU saved model will be saved to - ${BASE}_cpu/${VERSION}, such that "_cpu" is appended to the base path but - the numeric version is preserved. (b) A dict with key 'cpu' and as value - the path to write the CPU model to. - model_name: Name of model, like "/ml/user/half_plus_two". - warmup_examples: Optional list of warmup examples. If proveded, they will be - written in Predict mode to assets.extra. - tokenized_inputs: if True, inputs are expected to be pre-tokenized before - being passed to the Jax2TF converted model, e.g. an int32 tensor of type - [B, L]. If False, inputs is expected to be a string tensor of shape [B]. - We typically set tokenized_inputs to True if tokenization is handled by an - external service. This will disable tokenization in the preprocessor and - postprocessor. - write_warmup_example_fn: a callable which writes a set of warmup examples to - a pbtxt file for use validating a converted model. - mixture_or_task_name: Optional SeqIO task name used to get output features. - In order to set this output_features must be None. - validation_examples: Optional list of validation examples. If proveded, they - will be used to validate the latency and numeric accuracy of the TPU saved - model. - native_lowering: deprecated, always True. - native_lowering_platforms: In conjunction with `native_lowering`, specify - the platform(s) for which to lower the code. Must be a tuple of strings, - including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'. The default - (None), specifies the JAX default backend on the machine where the - lowering is done. - enable_xla: Defaults to true. If false, jax2tf conversion only emits non-XLA - ops. - decode_outputs: Optional bool. If provided, determines whether to decode the - output with the tokenizer, or to leave the output as is. - trailing_shapes: Optional mapping of model feature name to trailing shape, - the `...?` in `(batch_size, seqlen, ...?)`, which is needed to initialize - the model correctly. - output_vocab_feature_name: The vocabulary feature which maps decoded ids to - plain text. For standard T5X models this will always be 'targets', but may - be different or empty for other models. - signature_name: Name of the exported function. - create_polymorphic_shapes_fn: Optional function to create polymorphic shapes - for input tensors to the JAX model function. - """ # fmt: skip - # Always use native serialization. The non-native serialization is deprecated. - del native_lowering - jax.monitoring.record_event('/jax/t5x/export/beacon') - output_dirs = _standardize_output_dirs(output_dir) - del output_dir - - - logging.info('jax.process_count: %s', jax.process_count()) - logging.info('jax.local_devices: %s', jax.local_devices()) # Seems necessary. - logging.info('Creating inference function...') - train_state_initializer = get_train_state_initializer( - model, partitioner, task_feature_lengths, batch_size, trailing_shapes - ) - - output_features = _standardize_output_features( - mixture_or_task_name, output_features - ) - # Get the preprocessor and postprocessor. - - # Non-vanilla seq-to-seq/decoder-only models can have a different - # vocabulary feature or not use a vocabulary feature at all. - output_vocab = None - if output_vocab_feature_name: - output_vocab = output_features[output_vocab_feature_name].vocabulary - - # Handle the new and old create_preprocessor_fn signatures, for backwards - # compatibility. - # TODO(marcrasi): Delete after migrating clients. - if 'batch_size' in inspect.signature(create_preprocessor_fn).parameters: - # New signature. - preprocessor, input_signature = create_preprocessor_fn( - batch_size, output_features, task_feature_lengths, tokenized_inputs - ) # type: ignore - else: - # Old signature. - preprocessor = create_preprocessor_fn( - output_features, task_feature_lengths, tokenized_inputs - ) # type: ignore - input_signature = create_single_tensor_input_signature( - batch_size, task_feature_lengths, tokenized_inputs - ) - - logging.info('Converting inference function...') - - decoding_state_callback_fn = None - if create_decoding_state_callback_fn is not None: - decoding_state_callback_fn = create_decoding_state_callback_fn( - vocab=output_vocab, - call_tf_graph=True, - ) - - model_tf_fn = create_inference_function_fn( - model=model, - train_state_initializer=train_state_initializer, - decoding_state_callback_fn=decoding_state_callback_fn, - partitioner=partitioner, - inference_mode=inference_mode, - enable_jax2tf=True, - enable_xla=enable_xla, - polymorphic_shapes_inputs=create_polymorphic_shapes_fn( - input_signature, preprocessor - ), - native_lowering=True, - native_lowering_platforms=native_lowering_platforms, - ) - - logging.info('Loading parameters from checkpoint...') - params = load_params_from_checkpoint( - restore_checkpoint_cfg=restore_checkpoint_cfg, - train_state_initializer=train_state_initializer, - partitioner=partitioner, - ) - - logging.info('Preparing Module to save...') - if decode_outputs is None: - decode_outputs = not tokenized_inputs - postprocessor = create_postprocessor_fn( - output_vocab, inference_mode, decode_outputs - ) - module = exportable_module_cls( - preproc_tf_fn=preprocessor, - model_tf_fn=model_tf_fn, - postproc_tf_fn=postprocessor, - params=params, - batch_size=batch_size, - ) - module.preproc_func.get_concrete_function(*input_signature) - signatures = { - signature_name: module.__call__.get_concrete_function(*input_signature) - } - logging.info('Saving the CPU model...') - # TODO(b/196260374): Figure out how to set experimental_custom_gradients=True. - options = tf.saved_model.SaveOptions( - experimental_custom_gradients=False, - function_aliases={ - 'tpu_func': module.tpu_func, - }, - ) - tf.saved_model.save( - module, - output_dirs['cpu'], - signatures=signatures, - options=options, - ) - - - if warmup_examples: - if batch_size: - warmup_examples = warmup_examples[:batch_size] - while len(warmup_examples) < batch_size: - if tokenized_inputs: - warmup_examples.append(np.zeros_like(warmup_examples[0]).tolist()) - else: - warmup_examples.append('') - - write_warmup_example_fn( - warmup_examples, - output_dir=output_dirs['cpu'], - model_name=model_name, - batch_sizes=module.export_batch_sizes, - signature_name=signature_name, - ) - - - - # TODO(danielandor): Save the graph.pbtxt for debugging purposes. diff --git a/t5x-main/t5x/fiddle_configs/__init__.py b/t5x-main/t5x/fiddle_configs/__init__.py deleted file mode 100644 index eba3db493bf4b201775a8911a011ee37639d6410..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/t5x-main/t5x/fiddle_configs/configs/__init__.py b/t5x-main/t5x/fiddle_configs/configs/__init__.py deleted file mode 100644 index eba3db493bf4b201775a8911a011ee37639d6410..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/configs/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/t5x-main/t5x/fiddle_configs/configs/finetune.py b/t5x-main/t5x/fiddle_configs/configs/finetune.py deleted file mode 100644 index 9cb1cc69fd8d09f005929b2dbe7b9442202d6752..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/configs/finetune.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fiddle-config helpers equivalent to t5x/configs/runs/finetune.gin.""" - -from collections.abc import Mapping -import copy -from typing import Optional, Union - -import fiddle as fdl -import seqio -from t5x import config_utils -from t5x import models -from t5x import partitioning -from t5x import train as t5x_train # import __main__ as train_script -from t5x import trainer -from t5x import utils - - -# Defaults, commonly overridden -DROPOUT_RATE = 0.1 -USE_CACHED_TASKS = True -BATCH_SIZE = 128 - -# Defaults, sometimes overridden -EVAL_STEPS = 20 -EVAL_PERIOD = 1000 - -# Convenience overrides. -EVALUATOR_USE_MEMORY_CACHE = True -EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. -JSON_WRITE_N_RESULTS = None # Write all inferences. -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None - -MixtureOrTask = Union[str, seqio.Task, seqio.Mixture] - - -def train( - model: fdl.Buildable[models.BaseTransformerModel], - model_dir: Optional[str], - initial_checkpoint_path: str, - train_steps: int, - mixture_or_task_name: MixtureOrTask, - task_feature_lengths: Mapping[str, int], - eval_steps: int = EVAL_STEPS, - eval_period: int = EVAL_PERIOD, - relative_steps: Optional[int] = None, - random_seed: Optional[int] = RANDOM_SEED, - mixture_or_task_module: Optional[str] = None, - use_hardware_rng: bool = USE_HARDWARE_RNG, - batch_size: int = BATCH_SIZE, - use_cached_tasks: bool = USE_CACHED_TASKS, - json_write_n_results: Optional[bool] = JSON_WRITE_N_RESULTS, - evaluator_num_examples: Optional[bool] = EVALUATOR_NUM_EXAMPLES, - evaluator_use_memory_cache: bool = EVALUATOR_USE_MEMORY_CACHE, -) -> fdl.Buildable: - """Generate a configuration for running T5X `train()` launcher.""" - return fdl.Config( - t5x_train.train, - model=model, - model_dir=model_dir, - train_dataset_cfg=train_dataset_config( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=copy.copy(task_feature_lengths), - batch_size=batch_size, - use_cached_tasks=use_cached_tasks, - mixture_or_task_module=mixture_or_task_module, - ), - train_eval_dataset_cfg=train_eval_dataset_config( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=copy.copy(task_feature_lengths), - batch_size=batch_size, - use_cached_tasks=use_cached_tasks, - mixture_or_task_module=mixture_or_task_module, - ), - # Does not use `task_feature_lengths`. - infer_eval_dataset_cfg=infer_eval_dataset_config( - mixture_or_task_name=mixture_or_task_name, - batch_size=batch_size, - use_cached_tasks=use_cached_tasks, - mixture_or_task_module=mixture_or_task_module, - ), - checkpoint_cfg=checkpoint_config( - initial_checkpoint_path=initial_checkpoint_path, - ), - partitioner=fdl.Config( - partitioning.PjitPartitioner, - num_partitions=1, - model_parallel_submesh=None, - logical_axis_rules=fdl.Config( - partitioning.standard_logical_axis_rules - ), - ), - trainer_cls=fdl.Partial( - trainer.Trainer, - num_microbatches=None, - learning_rate_fn=fdl.ArgFactory( - utils.create_learning_rate_scheduler, - factors='constant', - base_learning_rate=0.001, - warmup_steps=1000, - ), - ), - total_steps=train_steps, - eval_steps=eval_steps, - eval_period=eval_period, - relative_steps=relative_steps, - random_seed=random_seed, - use_hardware_rng=use_hardware_rng, - summarize_config_fn=config_utils.summarize_fiddle_config, - inference_evaluator_cls=fdl.Partial( - seqio.Evaluator, - logger_cls=[ - fdl.Partial(seqio.PyLoggingLogger), - fdl.Partial(seqio.TensorBoardLogger), - fdl.Partial( - seqio.JSONLogger, write_n_results=json_write_n_results - ), - ], - num_examples=evaluator_num_examples, - use_memory_cache=evaluator_use_memory_cache, - ), - ) - - -def checkpoint_config( - initial_checkpoint_path: str, -) -> fdl.Buildable[utils.CheckpointConfig]: - return fdl.Config( - utils.CheckpointConfig, - restore=fdl.Config( - utils.RestoreCheckpointConfig, - path=initial_checkpoint_path, - mode='specific', - dtype='float32', - ), - save=fdl.Config( - utils.SaveCheckpointConfig, - period=5000, - dtype='float32', - keep=None, # keep all checkpoints, - save_dataset=False, # don't checkpoint dataset state - ), - ) - - -def train_dataset_config( - mixture_or_task_name: MixtureOrTask, - task_feature_lengths: Mapping[str, int], - batch_size: int, - use_cached_tasks: bool, - mixture_or_task_module: Optional[str], -) -> fdl.Buildable[utils.DatasetConfig]: - return fdl.Config( - utils.DatasetConfig, - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=copy.copy(task_feature_lengths), - split='train', - batch_size=batch_size, - shuffle=True, - seed=None, # use a new seed each run/restart - use_cached=use_cached_tasks, - pack=True, - module=mixture_or_task_module, - ) - - -def train_eval_dataset_config( - mixture_or_task_name: MixtureOrTask, - task_feature_lengths: Mapping[str, int], - batch_size: int, - use_cached_tasks: bool, - mixture_or_task_module: Optional[str], -) -> fdl.Buildable[utils.DatasetConfig]: - return fdl.Config( - utils.DatasetConfig, - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=copy.copy(task_feature_lengths), - split='validation', - batch_size=batch_size, - shuffle=False, - seed=42, - use_cached=use_cached_tasks, - pack=True, - module=mixture_or_task_module, - ) - - -def infer_eval_dataset_config( - mixture_or_task_name: MixtureOrTask, - batch_size: int, - use_cached_tasks: bool, - mixture_or_task_module: Optional[str], -) -> fdl.Buildable[utils.DatasetConfig]: - return fdl.Config( # pytype: disable=wrong-arg-types # use-fiddle-overlay - utils.DatasetConfig, - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=None, # compute max - split='validation', - batch_size=batch_size, - shuffle=False, - seed=42, - use_cached=use_cached_tasks, - pack=False, - module=mixture_or_task_module, - ) diff --git a/t5x-main/t5x/fiddle_configs/configs/pretrain.py b/t5x-main/t5x/fiddle_configs/configs/pretrain.py deleted file mode 100644 index da4068fb65f7870aca1a0b994658f19b15252822..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/configs/pretrain.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fiddle-config helpers equivalent to t5x/configs/runs/pretrain.gin.""" - -from collections.abc import Mapping -import copy -from typing import Optional, Union - -import fiddle as fdl -import seqio -from t5x import config_utils -from t5x import models -from t5x import partitioning -from t5x import train as t5x_train # import __main__ as train_script -from t5x import trainer -from t5x import utils - - -# Defaults, commonly overridden -USE_CACHED_TASKS = True -BATCH_SIZE = 128 - -# Defaults, sometimes overridden -EVAL_STEPS = 20 -EVAL_PERIOD = 1000 - -# HW RNG is faster than SW, but has limited determinism. -# Most notably it is not deterministic across different -# submeshes. -USE_HARDWARE_RNG = False -# None always uses faster, hardware RNG -RANDOM_SEED = None - -MixtureOrTask = Union[str, seqio.Task, seqio.Mixture] - - -def train( - model: fdl.Buildable[models.BaseTransformerModel], - model_dir: Optional[str], - train_steps: int, - mixture_or_task_name: MixtureOrTask, - task_feature_lengths: Mapping[str, int], - eval_steps: int = EVAL_STEPS, - eval_period: int = EVAL_PERIOD, - relative_steps: Optional[int] = None, - random_seed: Optional[int] = RANDOM_SEED, - mixture_or_task_module: Optional[str] = None, - use_hardware_rng: bool = USE_HARDWARE_RNG, - batch_size: int = BATCH_SIZE, - use_cached_tasks: bool = USE_CACHED_TASKS, -) -> fdl.Buildable: - """Generate a configuration for running T5X `train()` launcher.""" - return fdl.Config( - t5x_train.train, - model=model, - model_dir=model_dir, - train_dataset_cfg=train_dataset_config( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=copy.copy(task_feature_lengths), - batch_size=batch_size, - use_cached_tasks=use_cached_tasks, - mixture_or_task_module=mixture_or_task_module, - ), - train_eval_dataset_cfg=train_eval_dataset_config( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=copy.copy(task_feature_lengths), - batch_size=batch_size, - use_cached_tasks=use_cached_tasks, - mixture_or_task_module=mixture_or_task_module, - ), - infer_eval_dataset_cfg=None, - checkpoint_cfg=checkpoint_config(), - partitioner=fdl.Config( - partitioning.PjitPartitioner, - num_partitions=1, - model_parallel_submesh=None, - logical_axis_rules=fdl.Config( - partitioning.standard_logical_axis_rules - ), - ), - trainer_cls=fdl.Partial( - trainer.Trainer, - num_microbatches=None, - learning_rate_fn=fdl.ArgFactory( - utils.create_learning_rate_scheduler, - factors='constant * rsqrt_decay', - base_learning_rate=1.0, - # 10k to keep consistent with T5/MTF defaults. - warmup_steps=10000, - ), - ), - total_steps=train_steps, - eval_steps=eval_steps, - eval_period=eval_period, - relative_steps=relative_steps, - random_seed=random_seed, - use_hardware_rng=use_hardware_rng, - summarize_config_fn=config_utils.summarize_fiddle_config, - ) - - -def checkpoint_config() -> fdl.Buildable[utils.CheckpointConfig]: - return fdl.Config( - utils.CheckpointConfig, - restore=fdl.Config(utils.RestoreCheckpointConfig, path=[]), - save=fdl.Config( - utils.SaveCheckpointConfig, - period=1000, - dtype='float32', - keep=None, # keep all checkpoints, - save_dataset=False, # don't checkpoint dataset state - ), - ) - - -def train_dataset_config( - mixture_or_task_name: MixtureOrTask, - task_feature_lengths: Mapping[str, int], - batch_size: int, - use_cached_tasks: bool, - mixture_or_task_module: Optional[str], -) -> fdl.Buildable[utils.DatasetConfig]: - return fdl.Config( - utils.DatasetConfig, - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=task_feature_lengths, - split='train', - batch_size=batch_size, - shuffle=True, - seed=None, # use a new seed each run/restart - use_cached=use_cached_tasks, - pack=True, - module=mixture_or_task_module, - ) - - -def train_eval_dataset_config( - mixture_or_task_name: MixtureOrTask, - task_feature_lengths: Mapping[str, int], - batch_size: int, - use_cached_tasks: bool, - mixture_or_task_module: Optional[str], -) -> fdl.Buildable[utils.DatasetConfig]: - return fdl.Config( - utils.DatasetConfig, - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=task_feature_lengths, - split='validation', - batch_size=batch_size, - shuffle=False, - seed=42, - use_cached=use_cached_tasks, - pack=True, - module=mixture_or_task_module, - ) diff --git a/t5x-main/t5x/fiddle_configs/examples/__init__.py b/t5x-main/t5x/fiddle_configs/examples/__init__.py deleted file mode 100644 index eba3db493bf4b201775a8911a011ee37639d6410..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/examples/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/t5x-main/t5x/fiddle_configs/examples/t5_1_1.py b/t5x-main/t5x/fiddle_configs/examples/t5_1_1.py deleted file mode 100644 index 7fd2fbfcb8e2abe7f8030dd37fe2abd68b860adf..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/examples/t5_1_1.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Fiddle versions of examples in t5x/examples/t5_1_1/examples/*.gin.""" -import fiddle as fdl - -import seqio - -# Load task and mixture registrations. -# pylint: disable=unused-import -from t5.data import mixtures -from t5.data import tasks -# pylint: disable=unused-import - -from t5x import config_utils -from t5x import eval as t5x_eval -from t5x import partitioning -from t5x import utils - -from t5x.fiddle_configs.configs import finetune -from t5x.fiddle_configs.configs import pretrain -from t5x.fiddle_configs.models import t5_1_1 - - -def small_wmt_finetune() -> fdl.Buildable: - config = t5_1_1.small_config(dropout_rate=0.0) - model = t5_1_1.model( - config=config, - loss_normalizing_factor=233472, - ) - return finetune.train( - mixture_or_task_name='wmt_t2t_ende_v003', - model_dir=None, # To be set via --fdl_set="model_dir=..." - model=model, - task_feature_lengths={'inputs': 256, 'targets': 256}, - # 1000000 pre-trained steps + 20000 fine-tuning steps. - train_steps=1_020_000, - initial_checkpoint_path=( - 'gs://t5-data/pretrained_models/t5x/' - 't5_1_1_small/checkpoint_1000000' - ), - use_cached_tasks=False, - ) - - -def small_wmt_eval() -> fdl.Buildable: - config = t5_1_1.small_config(dropout_rate=0.0) - model = t5_1_1.model( - config=config, - ) - return fdl.Config( - t5x_eval.evaluate, - model=model, - partitioner=fdl.Config( - partitioning.PjitPartitioner, - num_partitions=1, - ), - dataset_cfg=fdl.Config( # pytype: disable=wrong-arg-types # use-fiddle-overlay - utils.DatasetConfig, - mixture_or_task_name='wmt_t2t_ende_v003', - task_feature_lengths=None, # Auto-computes the max lengths. - split='test', - batch_size=32, - shuffle=False, - seed=42, - ), - inference_evaluator_cls=fdl.Partial( - seqio.Evaluator, - logger_cls=[ - fdl.Partial(seqio.PyLoggingLogger), - fdl.Partial(seqio.TensorBoardLogger), - fdl.Partial(seqio.JSONLogger), - ], - num_examples=None, # Use all examples in the dataset. - use_memory_cache=True, - ), - summarize_config_fn=config_utils.summarize_fiddle_config, - restore_checkpoint_cfg=fdl.Config( # pytype: disable=wrong-arg-types # use-fiddle-overlay - utils.RestoreCheckpointConfig, - path=None, # Set via --fdl_set="restore_checkpoint_cfg.path=..." - mode='specific', - ), - output_dir=None, # Set via --fdl_set="output_dir=..." - ) - - -def small_c4_pretrain() -> fdl.Buildable: - config = t5_1_1.small_config(dropout_rate=0.0) - model = t5_1_1.model( - config=config, - ) - return pretrain.train( - mixture_or_task_name='c4_v220_span_corruption', - model_dir=None, # To be set via --fdl_set="model_dir=..." - model=model, - task_feature_lengths={'inputs': 512, 'targets': 114}, - train_steps=10000, - batch_size=256, - ) diff --git a/t5x-main/t5x/fiddle_configs/examples/t5_1_1_test.py b/t5x-main/t5x/fiddle_configs/examples/t5_1_1_test.py deleted file mode 100644 index 7e34286506ec8da570a16b877d1e59f5a9097e77..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/examples/t5_1_1_test.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5_1_1 examples.""" - -from absl.testing import absltest -import fiddle as fdl -from t5x import config_utils -from t5x.fiddle_configs.examples import t5_1_1 - - -def _prepare_config(config: fdl.Buildable) -> fdl.Buildable: - config = config_utils.prepare_to_summarize(config) - # Avoid executing config during fdl.Build - return fdl.cast(fdl.Partial, config) - - -class T511Test(absltest.TestCase): - - def test_partial_build_small_wmt_finetune(self): - config = t5_1_1.small_wmt_finetune() - config = _prepare_config(config) - fdl.build(config) - - def test_partial_build_small_wmt_eval(self): - config = t5_1_1.small_wmt_eval() - config = _prepare_config(config) - fdl.build(config) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/fiddle_configs/models/__init__.py b/t5x-main/t5x/fiddle_configs/models/__init__.py deleted file mode 100644 index eba3db493bf4b201775a8911a011ee37639d6410..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/models/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/t5x-main/t5x/fiddle_configs/models/t5_1_1.py b/t5x-main/t5x/fiddle_configs/models/t5_1_1.py deleted file mode 100644 index e1fa447d8e0bba2235a275eb18bee50ef72f425b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/fiddle_configs/models/t5_1_1.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5_1_1 Model Configurations, similar to t5x/examples/t5/t5_1_1/*.gin.""" -from typing import Optional - -import fiddle as fdl -import seqio -from t5x import adafactor -from t5x import models -from t5x import optimizers -from t5x.examples.t5 import network - -Z_LOSS = 0.0001 -LABEL_SMOOTHING = 0.0 -# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) -# the loss normalizing factor should be set to pretraining batch_size * -# target_token_length. -LOSS_NORMALIZING_FACTOR = None - - -def vocabulary() -> fdl.Buildable[seqio.SentencePieceVocabulary]: - return fdl.Config( - seqio.SentencePieceVocabulary, - sentencepiece_model_file=( - 'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model' - ), - ) - - -def optimizer() -> fdl.Buildable[optimizers.OptimizerDef]: - return fdl.Config( - adafactor.Adafactor, - decay_rate=0.8, - step_offset=0, - logical_factor_rules=fdl.Config( - adafactor.standard_logical_factor_rules, - ), - ) - - -def model( - config: fdl.Buildable[network.T5Config], - z_loss: float = Z_LOSS, - label_smoothing: float = LABEL_SMOOTHING, - loss_normalizing_factor: Optional[float] = LOSS_NORMALIZING_FACTOR, -) -> fdl.Buildable[models.BaseTransformerModel]: - return fdl.Config( - models.EncoderDecoderModel, - module=fdl.Config( # pytype: disable=wrong-arg-types # use-fiddle-overlay - network.Transformer, - config=config, - ), - input_vocabulary=vocabulary(), - output_vocabulary=vocabulary(), - optimizer_def=optimizer(), - z_loss=z_loss, - label_smoothing=label_smoothing, - loss_normalizing_factor=loss_normalizing_factor, - ) - - -def base_config( - dropout_rate: Optional[float], -) -> fdl.Buildable[network.T5Config]: - return fdl.Config( - network.T5Config, - # vocab size rounded to a multiple of 128 for TPU efficiency - vocab_size=32128, - dtype='bfloat16', - emb_dim=768, - num_heads=12, - num_encoder_layers=12, - num_decoder_layers=12, - head_dim=64, - mlp_dim=2048, - mlp_activations=('gelu', 'linear'), - dropout_rate=dropout_rate, - logits_via_embedding=False, - ) - - -def small_config( - dropout_rate: Optional[float], -) -> fdl.Buildable[network.T5Config]: - config = base_config(dropout_rate=dropout_rate) - return fdl.copy_with( - config, - emb_dim=512, - num_heads=6, - num_encoder_layers=8, - num_decoder_layers=8, - head_dim=64, - mlp_dim=1024, - ) diff --git a/t5x-main/t5x/gin_utils.py b/t5x-main/t5x/gin_utils.py deleted file mode 100644 index 64f118a21f6e76bc453b2e7a44de7357743de775..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/gin_utils.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for using gin configurations with T5X binaries.""" - -import os -from typing import Optional, Sequence, Union - -from absl import app -from absl import logging -from clu import metric_writers -import gin -import jax -from t5x import utils -import tensorflow as tf - - -@gin.configurable -def get_gin_config_str(show_provenance: bool = False) -> str: - """Utility function retrieving configuration in string form. - - This is only necessary to to provide a configurable name to toggle the - show_provenance parameter. - - Args: - show_provenance: Flag indicating whether to show (where possible) the - provenance of configuration settings. - - Returns: - Current gin configuration as string. - """ - # The following ensures that existing configs will not fail on old gin version - # that do not have the show_provenance parameter yet and makes this feature - # opt-in. - if show_provenance: - return gin.config_str(show_provenance=True) - else: - return gin.config_str() - - -def parse_gin_flags( - gin_search_paths: Sequence[str], - gin_files: Sequence[str], - gin_bindings: Sequence[str], - skip_unknown: Union[bool, Sequence[str]] = False, - finalize_config: bool = True, -): - """Parses provided gin files override params. - - Args: - gin_search_paths: paths that will be searched for gin files. - gin_files: paths to gin config files to be parsed. Files will be parsed in - order with conflicting settings being overriden by later files. Paths may - be relative to paths in `gin_search_paths`. - gin_bindings: individual gin bindings to be applied after the gin files are - parsed. Will be applied in order with conflicting settings being overriden - by later ones. - skip_unknown: whether to ignore unknown bindings or raise an error (default - behavior). Alternatively, a list of configurable names to skip if unknown. - finalize_config: whether to finalize the config so that it cannot be - modified (default behavior). - """ - # We import t5.data here since it includes gin configurable functions commonly - # used by task modules. - # TODO(adarob): Strip gin from t5.data and remove this import. - # Register .gin file search paths with gin - for gin_file_path in gin_search_paths: - gin.add_config_file_search_path(gin_file_path) - - # Parse config files and bindings passed via flag. - gin.parse_config_files_and_bindings( - gin_files, - gin_bindings, - skip_unknown=skip_unknown, - finalize_config=finalize_config, - ) - logging.info('Gin Configuration:') - for line in get_gin_config_str().splitlines(): - logging.info('%s', line) - - -def rewrite_gin_args(args: Sequence[str]) -> Sequence[str]: - """Rewrite `--gin.NAME=VALUE` flags to `--gin_bindings=NAME=VALUE`.""" - - def _rewrite_gin_arg(arg): - if not arg.startswith('--gin.'): - return arg - if '=' not in arg: - raise ValueError( - "Gin bindings must be of the form '--gin.=', got: " - + arg - ) - # Strip '--gin.' - arg = arg[6:] - name, value = arg.split('=', maxsplit=1) - r_arg = f'--gin_bindings={name} = {value}' - logging.info('Rewritten gin arg: %s', r_arg) - return r_arg - - return [_rewrite_gin_arg(arg) for arg in args] - - -@gin.register -def summarize_gin_config( - model_dir: str, - summary_writer: Optional[metric_writers.MetricWriter], - step: int, -): - """Writes gin config to the model dir and TensorBoard summary.""" - if jax.process_index() == 0: - config_str = get_gin_config_str() - tf.io.gfile.makedirs(model_dir) - # Write the config as JSON. - with tf.io.gfile.GFile(os.path.join(model_dir, 'config.gin'), 'w') as f: - f.write(config_str) - # Include a raw dump of the json as a text summary. - if summary_writer is not None: - summary_writer.write_texts(step, {'config': gin.markdown(config_str)}) - summary_writer.flush() - - -def run(main): - """Wrapper for app.run that rewrites gin args before parsing.""" - utils.run_main( - main, - flags_parser=lambda a: app.parse_flags_with_usage( - list(rewrite_gin_args(a)) - ), - ) # pytype: disable=wrong-arg-types - - -# ====================== Configurable Utility Functions ====================== - - -@gin.configurable -def sum_fn(var1=gin.REQUIRED, var2=gin.REQUIRED): - """sum function to use inside gin files.""" - return var1 + var2 - - -@gin.configurable -def bool_fn(var1=gin.REQUIRED): - """bool function to use inside gin files.""" - return bool(var1) - - -@gin.configurable -def string_split_fn( - text=gin.REQUIRED, separator=gin.REQUIRED, maxsplit=-1, index=None -): - """String split function to use inside gin files.""" - values = text.split(separator, maxsplit) - if index is None: - return values - else: - return values[index] diff --git a/t5x-main/t5x/gin_utils_test.py b/t5x-main/t5x/gin_utils_test.py deleted file mode 100644 index 1805787e70df057f0a91be102ba080feceb7cf46..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/gin_utils_test.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for gin_utils.""" - -from absl.testing import absltest -from t5x import gin_utils - - -class GinUtilsTest(absltest.TestCase): - - def test_rewrite_gin_args(self): - test_args = [ - '--gin_file=path/to/file', - 'gin.value=3', - '--gin.value=3', - '--gin.value="3"', - "--gin.value='3'", - '--gin.tricky="key = value"', - '--gin.dict={"foo": 4, "bar": "four"}', - '--gin.gin=bar', - '--gin.scope/foo=bar', - ] - expected_args = [ - '--gin_file=path/to/file', - 'gin.value=3', - '--gin_bindings=value = 3', - '--gin_bindings=value = "3"', - "--gin_bindings=value = '3'", - '--gin_bindings=tricky = "key = value"', - '--gin_bindings=dict = {"foo": 4, "bar": "four"}', - '--gin_bindings=gin = bar', - '--gin_bindings=scope/foo = bar', - ] - self.assertSequenceEqual( - gin_utils.rewrite_gin_args(test_args), expected_args - ) - - def test_rewrite_gin_args_malformed(self): - test_args = ['--gin.value=3', '--gin.test'] - with self.assertRaisesWithLiteralMatch( - ValueError, - "Gin bindings must be of the form '--gin.=', got: " - '--gin.test', - ): - gin_utils.rewrite_gin_args(test_args) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/infer.py b/t5x-main/t5x/infer.py deleted file mode 100644 index 47712165751d832ab16291e981ff4fad7d03b8fa..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/infer.py +++ /dev/null @@ -1,868 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint:disable=line-too-long -# pyformat: disable -r"""This script runs inference on a T5X-compatible model. - -""" -# pyformat: enable -# pylint:enable=line-too-long - -import concurrent.futures -import functools -import hashlib -import json -import os -import re -import shutil -import time -from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type - -from absl import logging -from clu import metric_writers -import jax -import jax.numpy as jnp -import numpy as np -import seqio -from t5x import gin_utils -from t5x import models -from t5x import partitioning -from t5x import utils -import tensorflow as tf -from tensorflow.io import gfile -from typing_extensions import Protocol - -# Automatically search for gin files relative to the T5X package. -_DEFAULT_GIN_SEARCH_PATHS = [ - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -] - -AUTOTUNE = tf.data.experimental.AUTOTUNE - - -class SummarizeConfigFn(Protocol): - - def __call__( - self, - model_dir: str, - summary_writer: Optional[metric_writers.SummaryWriter], - step: int, - ) -> None: - ... - - -class FailFastThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor): - """Wrapper for ThreadPoolExecutor that crashes main thread on exceptions. - - NOTE: this class should be used only from the main thread. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._incomplete_futures: List[concurrent.futures.Future] = [] - - def check_for_exceptions(self, wait: bool = False): - """Raises any exceptions from complete futures on the main thread.""" - still_incomplete_futures = [] - for future in self._incomplete_futures: - try: - exception = future.exception(timeout=0 if wait else None) - except concurrent.futures.TimeoutError: - still_incomplete_futures.append(future) - if exception is not None: - raise exception - - self._incomplete_futures = still_incomplete_futures - - def submit(self, *args, **kwargs) -> concurrent.futures.Future: - """Submit function to threadpool, capturing the returned future.""" - future = super().submit(*args, **kwargs) - self._incomplete_futures.append(future) - self.check_for_exceptions(wait=False) - return future - - def shutdown(self, *args, wait: bool = False, **kwargs): - self.check_for_exceptions(wait=wait) - super().shutdown(*args, **kwargs) - - -def create_task_from_tfexample_file( - paths: Sequence[str], - file_type: str, - inputs_key: str, - targets_key: Optional[str], - features: Mapping[str, seqio.Feature], - task_id: Optional[str] = None, -) -> str: - """Registers ad-hoc Task for file-based dataset of TFExamples. - - Args: - paths: Input file paths; all files should have type `file_type` and contain - binary-serialized TFExample protos. - file_type: Input file type; e.g., 'tfrecord', 'recordio', 'sstable'. For - keyed formats like 'sstable', we ignore the keys and use only the values. - inputs_key: Name of TFExample feature containing the input text for T5X. The - value of this feature should be a UTF8-encoded string. - targets_key: Optional name of a TFExample feature containing the target text - (relevant only in scoring mode). The value of this feature should be a - UTF8-encoded string. - features: Should have entries for keys 'inputs' and (if targets_key is not - None) 'targets', mapping to `seqio.Feature` objects that specify - attributes like vocabulary, add_eos, etc. These attributes are used for - preprocessing and featurizing the input text. - task_id: Task name identifier. By default, it is set to a unique and - deterministic hash id. Overrideable via this argument. - - Returns: - Name of the newly-registered Task. This Task has a split named 'infer' that - contains the preprocessed and featurized input dataset. - """ - # tf.io.gfile.glob supports lists, in contrast to gfile.glob. - files = tf.io.gfile.glob(paths) - if files: - logging.info('Using tfexample files %s', files) - else: - # Fail early if there's something wrong with the input file pattern. - raise ValueError('Missing or invalid paths: %s' % paths) - reader = { - 'tfrecord': tf.data.TFRecordDataset, - }[file_type] - - feature_description = {inputs_key: tf.io.FixedLenFeature([], tf.string)} - if targets_key: - feature_description[targets_key] = tf.io.FixedLenFeature([], tf.string) - - # Create a unique, deterministic task name. - if task_id is None: - task_id = hashlib.md5( - ':'.join(list(paths) + [inputs_key, targets_key or '']).encode() - ).hexdigest()[:10] - - task = seqio.TaskRegistry.add( - name=f'infer_{task_id}', - source=seqio.TFExampleDataSource( - {'infer': paths}, - feature_description=feature_description, - reader_cls=reader, - ), - preprocessors=[ - functools.partial( - seqio.preprocessors.rekey, - key_map={'inputs': inputs_key, 'targets': targets_key}, - ), - seqio.preprocessors.tokenize_and_append_eos, - ], - output_features=features, - ) - - return task.name - - -def merge_chunks_to_file( - output_dir: str, - output_fname: str, - tmp_dir: str, - step: Optional[int], -) -> None: - """Merge the predictions from different chunks into a unified file.""" - logging.info('Merging chunk results.') - # Merge chunks into single file. - chunk_paths = sorted( - gfile.glob(os.path.join(tmp_dir, f'{output_fname}-chunk?????')) - ) - - if not chunk_paths: - raise FileNotFoundError( - 'No chunk results found! One possible explanation is that your ' - 'input did not contain any examples' - ) - - assert int(chunk_paths[-1][-5:]) + 1 == len(chunk_paths), ( - f'Expecting {int(chunk_paths[-1][-5:]) + 1} chunk paths, found ' - f'{len(chunk_paths)}' - ) - output_path = os.path.join(output_dir, output_fname) - del step - with gfile.GFile(output_path, 'wb') as merged: - for chunk_path in chunk_paths: - with gfile.GFile(chunk_path, 'rb') as ef: - shutil.copyfileobj(ef, merged) - logging.info('Results written to %s.', output_path) - - -Inferences = Tuple[Sequence[Any], Mapping[str, Any]] -_Inferences = Inferences # Backwards-compatible alias; used by Colabs - - -def write_inferences_to_file( - path: str, - inferences: Inferences, - task_ds: tf.data.Dataset, - mode: str, - vocabulary: Optional[seqio.Vocabulary] = None, - json_encoder_cls: Type[json.JSONEncoder] = seqio.TensorAndNumpyEncoder, - include_all_inputs: bool = False, - input_fields_to_include: Optional[Sequence[str]] = None, - output_ids: bool = False, -) -> None: - """Write model predictions, along with pretokenized inputs, to JSONL file. - - Args: - path: File path to write to. - inferences: A tuple containing (predictions, aux_values). If mode is - 'predict' then the `predictions` will be token IDs. If it's 'score' then - it'll be a collection of scores. `aux_values` will be an empty dictionary - unless mode is 'predict_with_aux', in which case it'll contain the model's - auxiliary outputs. - task_ds: Original task dataset. Features from task with suffix - `_pretokenized` are added to the outputs. - mode: Prediction mode, either 'predict', 'score' or 'predict_with_aux'. - vocabulary: Task output vocabulary. Only used in `predict` mode in order to - decode predicted outputs into string. - json_encoder_cls: a JSON encoder class used to customize JSON serialization - via json.dumps. - include_all_inputs: if True, will include all model inputs in the output - JSONL file (including raw tokens) in addition to the pretokenized inputs. - input_fields_to_include: List of input fields to include in the output JSONL - file. This list should be None if `include_all_inputs` is set to True. - output_ids: if True, will output the token ID sequence for the output, in - addition to the decoded text. - """ - all_predictions, all_aux_values = inferences - - if mode in ('predict', 'predict_with_aux') and vocabulary is None: - raise ValueError( - 'The `vocabulary` parameter is required in `predict` and ' - '`predict_with_aux` modes' - ) - - def _json_compat(value): - if isinstance(value, bytes): - return value.decode('utf-8') - elif isinstance(value, (jnp.bfloat16, jnp.floating)): - return float(value) - elif isinstance(value, jnp.integer): - return float(value) - elif isinstance(value, (jnp.ndarray, np.ndarray)): - # Flatten array features. - return value.tolist() - else: - return value - - if include_all_inputs and input_fields_to_include is not None: - raise ValueError( - 'include_all_inputs and input_fields_to_include should not be set' - ' simultaneously.' - ) - with gfile.GFile(path, 'w') as f: - for i, inp in task_ds.enumerate().as_numpy_iterator(): - predictions = all_predictions[i] - aux_values = jax.tree.map( - f=lambda v, i=i: v[i], - tree=all_aux_values, - is_leaf=lambda v: isinstance(v, (np.ndarray, list)), - ) - - if include_all_inputs: - inputs = inp - elif input_fields_to_include is not None: - inputs = { - k: v - for k, v in inp.items() - if k in input_fields_to_include - or ( - k.endswith('_pretokenized') - and k[: -len('_pretokenized')] in input_fields_to_include - ) - } - else: - inputs = {k: v for k, v in inp.items() if k.endswith('_pretokenized')} - - json_dict = {} - json_dict['inputs'] = {k: _json_compat(v) for k, v in inputs.items()} - - if mode == 'predict': - assert vocabulary is not None - json_dict['prediction'] = _json_compat( - vocabulary.decode_tf(tf.constant(predictions)).numpy() - ) - if output_ids: - pred = _json_compat(tf.constant(predictions).numpy()) - # Truncate padding tokens. - assert isinstance(pred, list) - pred = pred[: pred.index(0)] if 0 in pred else pred - json_dict['prediction_tokens'] = pred - elif mode == 'score': - json_dict['score'] = _json_compat(predictions) - if aux_values: - json_dict['aux'] = jax.tree.map(_json_compat, aux_values) - elif mode == 'predict_with_aux': - assert vocabulary is not None - json_dict['prediction'] = _json_compat( - vocabulary.decode_tf(tf.constant(predictions)).numpy() - ) - if output_ids: - pred = _json_compat(tf.constant(predictions).numpy()) - # Truncate padding tokens. - pred = pred[: pred.index(0)] if 0 in pred else pred - json_dict['prediction_tokens'] = pred - json_dict['aux'] = jax.tree.map(_json_compat, aux_values) - else: - raise ValueError(f'Invalid mode: {mode}') - json_str = json.dumps(json_dict, cls=json_encoder_cls) - f.write(json_str + '\n') - - -WriteFn = Callable[ - [ - str, - Inferences, - tf.data.Dataset, - str, - Optional[seqio.Vocabulary], - ], - None, -] - -MergeFn = Callable[[str, str, str, Optional[int]], None] - - -def _extract_tokens_and_aux_values(inference_fn_outputs) -> Inferences: - """Extracts tokens and aux scores from a cached dataset.""" - all_aux_values = {} - if isinstance(inference_fn_outputs, tuple): - indices_and_tokens, all_aux_values = inference_fn_outputs - indices, tokens = zip(*indices_and_tokens) - - permutation = np.argsort(indices) - permute = lambda v: [v[permutation[i]] for i in range(len(permutation))] - tokens = permute(tokens) - all_aux_values = jax.tree.map( - f=permute, - tree=all_aux_values, - is_leaf=lambda v: isinstance(v, (np.ndarray, list)), - ) - - else: - indices_and_tokens = inference_fn_outputs - _, tokens = zip(*sorted(indices_and_tokens, key=lambda x: x[0])) - - return tokens, all_aux_values - - -def infer( - *, - mode: str, - model: models.BaseTransformerModel, - dataset_cfg: utils.DatasetConfig, - restore_checkpoint_cfg: utils.RestoreCheckpointConfig, - partitioner: partitioning.BasePartitioner, - output_dir: str, - checkpoint_period: int, - shard_id: int = 0, - num_shards: int = 1, - merge_chunked_results: bool = True, - write_fn: WriteFn = write_inferences_to_file, - checkpoint_ds_iter: bool = True, - train_state_initializer_cls: Type[ - utils.TrainStateInitializer - ] = utils.TrainStateInitializer, - fallback_init_rng: Optional[int] = None, - merge_fn: MergeFn = merge_chunks_to_file, - summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config, - verify_matching_vocabs_fn: Optional[ - Callable[[utils.DatasetConfig, models.BaseTransformerModel], None] - ] = utils.verify_matching_vocabs, - output_vocab_feature_name: str = 'targets', - file_extension: str = 'jsonl', - keep_aux_as_numpy: bool = False, - use_orbax: bool = True, -): - """Infer function. - - Args: - mode: Either 'predict' to decode targets, 'score' to compute the log - likelihood of given targets, or 'predict_with_aux' for both. - model: The model object to use for inference. - dataset_cfg: Specification for the dataset to infer based on. - restore_checkpoint_cfg: Specification for the model parameter checkpoint to - load. - partitioner: Partitioner for model parameters and data across devices. - output_dir: Path to directory to write temporary files and final results. - checkpoint_period: The intermediate results and dataset iterator will be - checkpointed on each multiple of this number of batches to enable - continuation after a failure. - shard_id: Index of dataset shard for this instance to use if splitting the - work across multiple jobs. - num_shards: Total number of dataset shards to split dataset across. - merge_chunked_results: Whether to merge results of all chunks into a single - json file. - write_fn: Callable function used to serialized and write inferences out to - files. - checkpoint_ds_iter: if True, will checkpoint the dataset iterator every - `checkpoint_period` to enable faster restore. This must be disabled for - certain datasets, for example since stateful iterators (e.g. from - seqio.FunctionTask) cannot be checkpointed. - train_state_initializer_cls: t5x.utils.TrainStateInitializer class for - initializing partitioned TrainState from checkpoints or scratch. - fallback_init_rng: A random seed used for parameter initialization during - model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is - set to True. If None, parameter initialization is not allowed during model - loading and having fallback_to_scratch enabled will result in an error. - merge_fn: Callable function used to merge inferences from multiple files. - summarize_config_fn: A function that takes in the model directory, an - optional SummaryWriter, and the step number, and writes a summary of the - configuration. SummaryWriter will be None in most cases. - verify_matching_vocabs_fn: Function to validate whether the task vocabulary - matches the model vocabulary. Should raise an exception on error. - output_vocab_feature_name: The name of the feature corresponding to the - output vocabulary. - file_extension: str. file extension used for file names - keep_aux_as_numpy: bool. whether to leave aux values as numpy arrays; can be - used to save space when saving bfloat16s - use_orbax: if True, uses Orbax for checkpointing. Experimental feature. - """ - jax.monitoring.record_event('/jax/t5x/infer/beacon') - logging.info('Process ID: %d', jax.process_index()) - - # Only allow `shard_id` 0 to write config summary, since the config summary - # does NOT depend on `shard_id`. - if shard_id == 0: - summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0) - - if mode not in ('predict', 'score', 'predict_with_aux'): - raise ValueError( - "`mode` must be one of 'predict', 'score' or 'predict_with_aux'. " - f"Got '{mode}'" - ) - - # Remove double-slashes in directory path to avoid inconsistencies. - output_dir = re.sub(r'(? 1: - raise app.UsageError('Too many command-line arguments.') - - if FLAGS.tfds_data_dir: - seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) - - - if config_utils.using_fdl(): - config = config_utils.config_with_fiddle(infer) - shard_id = FLAGS.shard_id - if shard_id is not None: - config.shard_id = shard_id - infer_with_fiddle = fdl.build(config) - if shard_id == 0: - config_utils.direct_summarize_fiddle_config( - model_dir=infer_with_fiddle.output_dir, - summary_writer=None, - step=0, - get_current_fiddle_config=lambda: infer_with_fiddle, - ) - infer_with_fiddle() - else: - # Create gin-configurable version of `infer`. - infer_using_gin = gin.configurable(infer) - - gin_utils.parse_gin_flags( - # User-provided gin paths take precedence if relative paths conflict. - FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, - FLAGS.gin_file, - FLAGS.gin_bindings, - ) - - # See http://yaqs/7882016229479677952 for further gin-config discussion. - def _get_gin_parameter(key: str) -> Any: - value = gin.query_parameter(key) - if isinstance(value, gin.config.ConfigurableReference): - if value.evaluate: - return value.scoped_configurable_fn() - return value.scoped_configurable_fn - return value - - shard_id = ( - FLAGS.shard_id - if FLAGS.shard_id is not None - else _get_gin_parameter('infer.shard_id') - ) - if shard_id == 0: - gin_utils.summarize_gin_config( - model_dir=_get_gin_parameter('infer.output_dir'), - summary_writer=None, - step=0, - ) - if FLAGS.shard_id is not None: - # We fall back to this flag since XM does not support sweeps over flags - # with '.' in them (it treats them like nested dictionaries). - # TODO(adarob): Figure out a workaround so we can deprecate this flag. - infer_using_gin(shard_id=FLAGS.shard_id) - else: - infer_using_gin() - - - config_utils.run(main) diff --git a/t5x-main/t5x/interactive_model.py b/t5x-main/t5x/interactive_model.py deleted file mode 100644 index ea453d7593130529223728c03bf033fc3f492864..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/interactive_model.py +++ /dev/null @@ -1,1279 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""InteractiveModel class for use in T5X Colabs. - -The InteractiveModel can be used to run training, inference, and evaluation on -natural text inputs and targets. - -""" - -import abc -from collections.abc import Mapping, Sequence -import enum -import functools -import inspect -import itertools -import logging -import os -import re -from typing import Any, Callable, Iterator, Optional, Tuple, Union - -import clu.data.dataset_iterator -import jax -from jax import random -from jax.experimental import multihost_utils -import numpy as np -import seqio -from t5x import checkpoints -from t5x import models -from t5x import partitioning -from t5x import trainer as trainer_lib -from t5x import utils -from t5x.infer import _extract_tokens_and_aux_values -from t5x.infer import _Inferences -import tensorflow as tf -import tensorflow_datasets as tfds - -BatchesType = Union[ - Sequence[Mapping[str, str]], Sequence[Sequence[Mapping[str, str]]] -] - - -class InferenceType(enum.Enum): - PREDICT_WITH_AUX = 1 - SCORE = 2 - - -class T5XScriptType(enum.Enum): - FINETUNING = 1 - INFERENCE = 2 - EVALUATION = 3 - PRETRAINING = 4 - - -class InteractiveModel(abc.ABC): - """Wrapper around T5X components to enable interactive train/infer/eval.""" - - def __init__( - self, - batch_size: int, - task_feature_lengths: Mapping[str, int], - output_dir: str, - partitioner: partitioning.BasePartitioner, - model: models.BaseTransformerModel, - dtype: Optional[str], - restore_mode: str, - checkpoint_path: str, - input_shapes: Mapping[str, utils.Array], - input_types: Optional[Mapping[str, utils.DType]] = None, - init_random_seed: int = 42, - add_eos: bool = True, - eval_names: Optional[Sequence[str]] = None, - ): - """Init function. - - Configures the output directory, RNGs, and partitioner given the provided - arguments. - - Args: - batch_size: number of examples per batch for training, inference, and - evaluation. - task_feature_lengths: dictionary mapping feature key to maximum length - (int) for that feature. If feature is longer than this length after - preprocessing, the feature will be truncated. May be set to None to - avoid truncation. - output_dir: Path to directory where we will write temporary files and - final results. - partitioner: the partitioner that defines how we divide and replicate - machine learning model parameters, activations, and data across the - accelerator devices (TPU/GPU). See https://github.com/google-research/t5x/blob/main/docs/usage.md/partitioning for - details. - model: the model object to use for training, inference, and evaluation. - dtype: The dtype to restore ('float32' or 'bfloat16'), or None to load as - saved. - restore_mode: One of 'specific', 'latest', or 'all'. `specific` loads the - checkpoint specified by `path`. `latest` loads the most recent - checkpoint in the directory specified by `path`. `all` sequentially - loads all of checkpoints in the directory `path`. - checkpoint_path: Path(s) to checkpoint to restore from or directory - (depending on `restore_mode`). - input_shapes: a mapping from key to array shape for each feature in the - global (unsharded) input batch. - input_types: a mapping from key to array type for each feature in the - global (unshared) input batch. If not provided, the type is assumed to - be `jnp.float32`. - init_random_seed: the random seed used to initialize all RNGs. - add_eos: whether or not to add the EOS token to inputs/targets. - eval_names: names of evaluation datasets, which must match the keys of the - mapping passed to trainer's `eval` method. - - Raises: - ValueError: the partitioner has an incorrect submesh, or the checkpoint - restore function returned a sequence of TrainStates, when it should have - returned a single TrainState. - """ - self._batch_size = batch_size - self._task_feature_lengths = task_feature_lengths - self._cached_infer_fns = {} - # -------------------------------------------------------------------------- - # Configure the output directory - # -------------------------------------------------------------------------- - self._output_dir = output_dir - # Remove double-slashes in directory path to avoid inconsistencies. - self._output_dir = re.sub(r"(? _Inferences: - """Infer function. - - Args: - mode: Either 'score' to compute the log likelihood of given targets, or - 'predict_with_aux' to score and decode targets. - examples: examples that should be transformed into a tf.data.Dataset. The - examples can either take the form of a string (ex: a single input for - inference), or a dictionary mapping "input"/"target" to a string - containing that element. - preprocessors: list(callable), an optional list of functions that receive - a tf.data.Dataset and return a tf.data.Dataset. These will be executed - sequentially and the final dataset must include features matching - `self._features`. - **inference_kwargs: additional keyword arguments to pass to the inference - function (e.g., `model.predict_batch_with_aux` or `score_batch`). - - Returns: - Returns a tuple of predictions/scores and any auxiliary values. - """ - # -------------------------------------------------------------------------- - # Parse Mode - # -------------------------------------------------------------------------- - if mode == InferenceType.PREDICT_WITH_AUX: - infer_step = self._model.predict_batch_with_aux - elif mode == InferenceType.SCORE: - infer_step = self._model.score_batch - else: - raise ValueError( - "Mode must be `predict_with_aux`, or `score`," - f" but instead was {mode}." - ) - key_array = seqio.utils.flatten_dict(inference_kwargs) - key_array["mode"] = mode - infer_fn_key = tuple(key_array.items()) - if infer_fn_key not in self._cached_infer_fns: - self._cached_infer_fns[infer_fn_key] = utils.get_infer_fn( - infer_step=functools.partial(infer_step, **inference_kwargs), - batch_size=self._batch_size, - train_state_axes=self._train_state_initializer.train_state_axes, - partitioner=self._partitioner, - ) - infer_fn = functools.partial( - self._cached_infer_fns[infer_fn_key], - train_state=self._train_state, - ) - - # -------------------------------------------------------------------------- - # Construct a dataset and dataset iterator. - # -------------------------------------------------------------------------- - dataset = get_dataset_from_natural_text_examples( - examples, - preprocessors=preprocessors, - task_feature_lengths=self._task_feature_lengths, - features=self._features, - ) - model_dataset = self._feature_converter( - dataset, task_feature_lengths=self._task_feature_lengths - ) - # Zip task and model features. - infer_dataset = tf.data.Dataset.zip((dataset, model_dataset)) - # Create batches and index them. - infer_dataset = infer_dataset.padded_batch( - self._batch_size, drop_remainder=False - ).enumerate() - infer_dataset_iter: Iterator[Tuple[int, Any]] = iter( - infer_dataset.prefetch(tf.data.experimental.AUTOTUNE) - ) - - # -------------------------------------------------------------------------- - # Run inference - # -------------------------------------------------------------------------- - # Main Loop over "batches". - all_inferences = [] - all_aux_values = {} - for chunk, chunk_batch in infer_dataset_iter: - # Load the dataset for the next chunk. We can't use `infer_dataset_iter` - # directly since `infer_fn` needs to know the exact size of each chunk, - # which may be smaller for the final one. - chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch) - chunk_dataset.cache().prefetch(tf.data.experimental.AUTOTUNE) - - # Unzip chunk dataset in to pretokenized and model datasets. - task_dataset = chunk_dataset.map( - lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE - ) - model_dataset = chunk_dataset.map( - lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE - ) - - # Get a chunk-specific RNG key. - chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk) - - inferences = _extract_tokens_and_aux_values( - infer_fn(model_dataset.enumerate(), rng=chunk_rng) - ) - - predictions, aux_values = inferences - accumulated_inferences = [] - for idx, inputs in task_dataset.enumerate().as_numpy_iterator(): - prediction = predictions[idx] - # Decode predictions if applicable. - if mode == InferenceType.PREDICT_WITH_AUX: - prediction = ( - self._features["targets"] - .vocabulary.decode_tf(tf.constant(prediction)) - .numpy() - ) - accumulated_inferences.append((inputs, prediction)) - all_inferences += accumulated_inferences - # Accumulate aux values over batches. - if not all_aux_values: - all_aux_values = aux_values - else: - for key, values in aux_values.items(): - all_aux_values[key] += values - - return all_inferences, all_aux_values - - def predict_with_aux( - self, examples: Sequence[Union[str, dict[str, str]]] - ) -> _Inferences: - """Predict with auxiliary values method.""" - # By default, only tokenize and append EOS. - preprocessors = [ - seqio.preprocessors.tokenize, - seqio.preprocessors.append_eos, - ] - return self.infer_with_preprocessors( - mode=InferenceType.PREDICT_WITH_AUX, - examples=examples, - preprocessors=preprocessors, - ) - - def score( - self, examples: Sequence[Union[str, dict[str, str]]] - ) -> Sequence[Any]: - """Score method.""" - # By default, only tokenize and append EOS. - preprocessors = [ - seqio.preprocessors.tokenize, - seqio.preprocessors.append_eos, - ] - # Ignore auxiliary values. - scores, _ = self.infer_with_preprocessors( - mode=InferenceType.SCORE, examples=examples, preprocessors=preprocessors - ) - return scores - - def _compute_metrics( - self, - targets: Sequence[Any], - predictions: Sequence[Any], - aux_values: Sequence[Any], - scores: Sequence[Any], - predict_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable], - predict_with_aux_metric_fns: Sequence[ - seqio.dataset_providers.MetricFnCallable - ], - score_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable], - ): - """Computes the metrics specified in the metric_fns lists.""" - # Only compute metrics once - if jax.process_index() != 0: - return {} - - def compute_metrics_fn(): - task_metrics = [] - if predict_metric_fns: - task_metrics.extend([ - metric_fn(targets, predictions) for metric_fn in predict_metric_fns - ]) - if predict_with_aux_metric_fns: - task_metrics.extend([ - metric_fn(targets, predictions, aux_values) - for metric_fn in predict_with_aux_metric_fns - ]) - if score_metric_fns: - is_tuple = isinstance(scores, tuple) - if (not is_tuple and len(targets) != len(scores)) or ( - is_tuple and len(targets) != len(scores[0]) - ): - raise ValueError( - f"len(targets)({len(targets)}) != " - f"len(output_scores)({len(scores)})" - ) - task_metrics.extend( - [metric_fn(targets, scores) for metric_fn in score_metric_fns] - ) - - all_metrics = {} - for k, v in itertools.chain(*[m.items() for m in task_metrics]): - if k in all_metrics: - raise ValueError(f"Duplicate metric key '{k}' in Task.") - all_metrics[k] = v - return all_metrics - - if not tf.executing_eagerly(): - - def wrap_graph(fn): - graph = tf.compat.v1.get_default_graph() - - def wrapped_fn(): - with graph.as_default(): - return fn() - - return wrapped_fn - - compute_metrics_fn = wrap_graph(compute_metrics_fn) - - all_metrics = compute_metrics_fn() - # Wait until computations are done before continuing. - utils.sync_global_devices("Completed.") - return all_metrics - - def evaluate( - self, - examples: Sequence[Union[str, dict[str, str]]], - metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable], - ) -> Mapping[Any, Any]: - """Evaluation function. - - Args: - examples: examples that should be transformed into a tf.data.Dataset. The - examples can either take the form of a string (ex: a single input for - inference), or a dictionary mapping "input"/"target" to a string - containing that element. - metric_fns: list(callable), an optional list of metric functions with a - signature that matches one of three possible forms: - (targets, scores) - - Note that `scores` refers to the score the model assigned the target - sequence, given the input. - (targets, predictions) - (targets, - predictions, aux_values) - Note that `aux_values` refers to a dictionary - of auxiliary values that the model assigned to each sequence. - - Returns: - Mapping of metrics names to metrics values. - """ - # By default, only tokenize and append EOS. - preprocessors = [ - seqio.preprocessors.tokenize, - seqio.preprocessors.append_eos, - ] - return self.evaluate_with_preprocessors( - examples=examples, - preprocessors=preprocessors, - metric_fns=metric_fns, - postprocessor=None, - ) - - def evaluate_with_preprocessors( - self, - examples: Sequence[dict[str, str]], - preprocessors: Sequence[Callable[..., tf.data.Dataset]], - metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable], - postprocessor: Optional[Callable[..., Any]] = None, - ) -> Mapping[Any, Any]: - """Evaluation function. - - Args: - examples: examples that should be transformed into a tf.data.Dataset. The - examples must take the form of a dictionary mapping "input"/"target" to - a string containing that element. - preprocessors: list(callable), an optional list of functions that receive - a tf.data.Dataset and return a tf.data.Dataset. These will be executed - sequentially and the final dataset must include features matching - `self._features`. - metric_fns: list(callable), an optional list of metric functions with a - signature that matches one of three possible forms: - (targets, scores) - - Note that `scores` refers to the score the model assigned the target - sequence, given the input. - (targets, predictions) - (targets, - predictions, aux_values) - Note that `aux_values` refers to a dictionary - of auxiliary values that the model assigned to each sequence. - postprocessor: callable, an optional function that receives decoded model - outputs and converts them to a form that is ready for evaluation using - the metric functions in `metric_fns`. - - Returns: - Mapping of metrics names to metrics values. - """ - # -------------------------------------------------------------------------- - # Parse Metrics functions - # -------------------------------------------------------------------------- - predict_metric_fns = [] - predict_with_aux_metric_fns = [] - score_metric_fns = [] - for metric_fn in metric_fns: - pos_args = tuple( - key - for key, param in inspect.signature(metric_fn).parameters.items() - if param.default == inspect.Parameter.empty - ) - if pos_args == ("targets", "scores"): - score_metric_fns.append(metric_fn) - elif pos_args == ("targets", "predictions"): - predict_metric_fns.append(metric_fn) - elif pos_args == ("targets", "predictions", "aux_values"): - predict_with_aux_metric_fns.append(metric_fn) - else: - raise ValueError( - "Metric functions must have positional arguments matching either " - "('targets', 'scores'), ('targets', 'predictions') or " - "('targets', 'predictions', 'aux_values'). " - f"Got: {pos_args}" - ) - - # ------------------------------------------------------------------------ - # Get targets, predictions, and scores - # ------------------------------------------------------------------------ - dataset = get_dataset_from_natural_text_examples( - examples, - preprocessors=preprocessors, - task_feature_lengths=self._task_feature_lengths, - features=self._features, - ) - - # Get targets. - def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any: - """Returns the model output after applying the postprocess function.""" - if postprocessor: - return postprocessor(decoded_model_output, **postprocess_kwargs) - return decoded_model_output - - targets = [] - for ex in tfds.as_numpy(dataset): - targets.append( - postprocess_fn( - decoded_model_output=ex["targets_pretokenized"], - example=ex, - is_target=True, - ) - ) - - # Get predictions. - predictions = [] - if predict_with_aux_metric_fns or predict_metric_fns: - predictions, aux_values = self.infer_with_preprocessors( - mode=InferenceType.PREDICT_WITH_AUX, - examples=examples, - preprocessors=preprocessors, - ) - predictions = [ - prediction.decode("utf-8") for example, prediction in predictions - ] - # Get scores. - scores = [] - if score_metric_fns: - scores, _ = self.infer_with_preprocessors( - mode=InferenceType.SCORE, - examples=examples, - preprocessors=preprocessors, - ) - scores = [score for example, score in scores] - - return self._compute_metrics( - targets, - predictions, - aux_values, - scores, # pytype: disable=wrong-arg-types # mapping-is-not-sequence - predict_metric_fns, - predict_with_aux_metric_fns, - score_metric_fns, - ) - - def train_loop( - self, - num_steps: int, - eval_period: Optional[int] = 1, - train_batches: Optional[BatchesType] = None, - predict_batches: Optional[BatchesType] = None, - score_batches: Optional[BatchesType] = None, - eval_batches: Optional[BatchesType] = None, - metrics_fns: Optional[ - Sequence[seqio.dataset_providers.MetricFnCallable] - ] = None, - ): - """Runs training, inference, and evaluation for `num_steps`. - - It should be noted that there are many different possible variants of the - `train_loop` function that a user might want to use. The primary goal of the - `train_loop` function is not to cover all the potential training loop - variants that a user may want; rather, the goal is to demonstrate how the - user could stack the `InteractiveModel` train, predict, score, and evaluate - methods. - - Args: - num_steps: the number of steps to run for training, inference, and - evaluation. - eval_period: specifies how many steps to take between - inference/evaluation. - train_batches: an optional list of batches that we should run training on. - If no batches are provided, then training will be skipped. If a single - batch is provided, we will repeat training on this batch for - `num_steps`. - predict_batches: an optional list of batches that we should get - predictions for. If no batches are provided, then predicting will be - skipped. If a single batch is provided, we will repeatedly get - predictions on this batch for `num_steps`. - score_batches: an optional list of batches that we should score. If no - batches are provided, then scoring will be skipped. If a single batch is - provided, we will repeatedly score this batch for `num_steps`. - eval_batches: an optional list of batches that we should run eval on. If - no batches are provided, then evaluation will be skipped. If a single - batch is provided, we will repeatedly evaluate this batch for - `num_steps`. - metrics_fns: list(callable), an optional list of metric functions with a - signature that matches one of three possible forms: - (targets, scores) - - Note that `scores` refers to the score the model assigned the target - sequence, given the input. - (targets, predictions) - (targets, - predictions, aux_values) - Note that `aux_values` refers to a dictionary - of auxiliary values that the model assigned to each sequence. - - Returns: - Predictions, scores, and metrics for the final step of the training loop. - """ - # Ensure all batches are `num_steps` in length - train_batches = _get_equal_length_batches(train_batches, num_steps) - - predictions = None - scores = None - metrics = None - for step_num, train_batch in enumerate(train_batches): - if train_batch: - self.train_step(train_batch) - # Run inference/evaluation every `eval_period` steps. - if step_num % eval_period == 0: - # Run on all batches for inference/evaluation. - if predict_batches: - for predict_batch in predict_batches: - predictions, _ = self.predict_with_aux(predict_batch) # pytype: disable=wrong-arg-types # mapping-is-not-sequence - if score_batches: - for score_batch in score_batches: - scores = self.score(score_batch) # pytype: disable=wrong-arg-types # mapping-is-not-sequence - if eval_batches: - for eval_batch in eval_batches: - metrics = self.evaluate(eval_batch, metrics_fns) # pytype: disable=wrong-arg-types # mapping-is-not-sequence - return predictions, scores, metrics - - -def get_dataset_from_natural_text_examples( - examples: Sequence[Union[str, dict[str, str]]], - preprocessors: Sequence[Callable[..., tf.data.Dataset]], - task_feature_lengths: Mapping[str, int], - features: Mapping[str, Any], -) -> tf.data.Dataset: - """Returns a tf.data.Dataset from a list of examples. - - Args: - examples: a single batch of examples that should be transformed into a - tf.data.Dataset. The examples can either take the form of a string (ex: a - single input for inference), or a dictionary mapping "input"/"target" to a - string containing that element. - preprocessors: an optional list of functions that receive a tf.data.Dataset - and return a tf.data.Dataset. These will be executed sequentially and the - final dataset must include features matching `self._features`. - task_feature_lengths: dictionary mapping feature key to maximum length (int) - for that feature. If feature is longer than this length after - preprocessing, the feature will be truncated. May be set to None to avoid - truncation. - features: dictionary defining what features should be present in all - examples. - - Returns: - A tf.data.Dataset. - """ - # ------------------------------------------------------------------------ - # Construct a `tf.data.Dataset` from the provided examples - # ------------------------------------------------------------------------ - merged_examples = {"inputs": [], "targets": []} - for example in examples: - # If the provided example is just a string, add an empty target string - if isinstance(example, dict): - example_dict = example - else: - example_dict = {"input": example, "target": ""} - merged_examples["inputs"].append(example_dict["input"]) - merged_examples["targets"].append(example_dict["target"]) - dataset = tf.data.Dataset.from_tensor_slices(merged_examples) - - # Define `ShardInfo` that doesn't shard the data pipeline. - shard_info = seqio.ShardInfo(0, 1) - dataset = dataset.shard(shard_info.num_shards, shard_info.index) - - # ------------------------------------------------------------------------ - # Preprocess data - # ------------------------------------------------------------------------ - for prep_fn in preprocessors: - # prep_fn must not rely on variable length keyword args such as **kwargs. - fn_args = set(inspect.signature(prep_fn).parameters.keys()) - kwargs = {} - if "sequence_length" in fn_args: - kwargs["sequence_length"] = task_feature_lengths - if "output_features" in fn_args: - kwargs["output_features"] = features - dataset = prep_fn(dataset, **kwargs) - - def _validate_preprocessing(dataset: tf.data.Dataset) -> tf.data.Dataset: - """Validates preprocessed dataset, raising Exceptions if needed. - - Args: - dataset: a tf.data.Dataset to validate. - - Returns: - a validated tf.data.Dataset. - - Raises: - ValueError: dataset has missing feature or the incorrect type/rank for a - feature. - """ - actual_specs = dataset.element_spec - for feat, feat_spec in features.items(): - if feat not in actual_specs: - if feat_spec.required: - raise ValueError( - "Task dataset is missing expected output feature after " - f"preprocessing: {feat}" - ) - else: - # It's ok that this feature does not exist. - continue - actual_spec = actual_specs[feat] - if feat_spec.dtype != actual_spec.dtype: - raise ValueError( - f"Task dataset has incorrect type for feature '{feat}' after " - f"preprocessing: Got {actual_spec.dtype.name}, expected " - f"{feat_spec.dtype.name}" - ) - if feat_spec.rank != actual_spec.shape.rank: - raise ValueError( - f"Task dataset has incorrect rank for feature '{feat}' after " - f"preprocessing: Got {actual_spec.shape.rank}, expected " - f"{feat_spec.rank}" - ) - - return dataset - - dataset = _validate_preprocessing(dataset) - dataset = seqio.utils.trim_dataset(dataset, task_feature_lengths, features) - return dataset.prefetch(tf.data.experimental.AUTOTUNE) - - -def _get_equal_length_batches( - batches: BatchesType, length: int -) -> Sequence[Any]: - """Produces a list of batches that is `length` batches long. - - Given a single batch, repeat the batch `length` times. - - Given a list of batches, either repeat the batches to get `length` total - batches or take the first 'length' batches. - - Args: - batches: either a single batch of examples, or a list of batches. - length: the total number of batches that should be present in the final - list. - - Returns: - A list of batches. - """ - # Given a list of batches, return a list of batches that is `length` long, - # either by repeating the batches or taking the first `length` batches - if not batches: - return [None] * length - if isinstance(batches[0], Mapping): - return [batches for i in range(length)] - if len(batches) < length: - batches = batches * (length // len(batches)) - # If multiple batches are provided, only use the first `length` batches. - logging.warning( - "We will only use the first %s batches provided for training.", length - ) - return batches[:length] - - -def get_batches_from_seqio( - task_or_mixture_name: str, - split: str, - batch_size: int, - num_batches: int, - get_pretokenized_examples: bool = True, - sequence_length: Optional[Mapping[str, int]] = None, - **get_dataset_kwargs, -) -> Sequence[Sequence[Mapping[str, str]]]: - """Returns a batch of examples from a provided SeqIO task. - - Args: - task_or_mixture_name: the SeqIO task/mixture to read data from. - split: the split of the SeqIO task/mixture to read data from. - batch_size: how many examples should be in each batch. - num_batches: the total number of batches to return. - get_pretokenized_examples: a bool, where True indicates that we should - return the natural text (pre-tokenization) inputs and targets. Default to - True in order to make the examples easy to debug/inspect. - sequence_length: dictionary mapping feature key to maximum length (int) for - that feature. Used by SeqIO to get the dataset. - **get_dataset_kwargs: any additional arguments that should be passed to the - SeqIO `get_dataset()` call. - - Returns: - A sequence of batches, where each batch is a sequence of examples. Each - example is a dictionary mapping 'input' and 'target' to the corresponding - values for a single example. - """ - task_or_mixture = seqio.get_mixture_or_task(task_or_mixture_name) - total_examples_requested = batch_size * num_batches - dataset = task_or_mixture.get_dataset( - sequence_length=sequence_length, split=split, **get_dataset_kwargs - ) - - all_batches = [] - current_batch = [] - input_key = "inputs_pretokenized" if get_pretokenized_examples else "inputs" - target_key = ( - "targets_pretokenized" if get_pretokenized_examples else "targets" - ) - total_examples_seen = 0 - # It should be noted that we could replace the following loop with tf.Dataset - # operations (like - # `list(dataset.batch(batch_size).take(num_batches).as_numpy_iterator())`), - # but this would require us to pad batches first or represent the token IDs as - # ragged tensors. These approaches are currently overkill for the - # InteractiveModel, but may be investigated in the future. - dataset = dataset.take(total_examples_requested) - for idx, element in enumerate(dataset.as_numpy_iterator()): - total_examples_seen += 1 - if idx >= total_examples_requested: - # Because we force `num_examples_requested` to be a multiple of - # `batch_size`, this should enforce that the last batch always has the - # same number of examples as all other batches. - break - - example_input = element[input_key] - example_target = element[target_key] - if not get_pretokenized_examples: - example_input = example_input.tolist() - example_target = example_target.tolist() - current_example = {"input": example_input, "target": example_target} - current_batch.append(current_example) - - # If we've collected `batch_size` examples, save the current batch and start - # a new batch. - if len(current_batch) == batch_size: - all_batches.append(current_batch) - current_batch = [] - - if total_examples_seen < total_examples_requested: - raise ValueError( - "Not enough examples in Task/Mixture. User requested " - f"{num_batches} batches of size {batch_size} for a total " - f"of {total_examples_requested} examples. Only " - f"{total_examples_seen} available in " - "Task/Mixture." - ) - - return all_batches - - -def get_seqio_task_from_examples( - task_name: str, - interactive_model: InteractiveModel, - examples: Sequence[Union[str, dict[str, str]]], - preprocessors: Sequence[Callable[..., tf.data.Dataset]], - metric_fns: Optional[ - Sequence[seqio.dataset_providers.MetricFnCallable] - ] = None, - add_to_registry: bool = True, -) -> Union[seqio.Task, seqio.Mixture]: - """Registers and returns a SeqIO task from the provided inputs. - - This function will be used to graduate people to the T5X/SeqIO-based - train/infer/eval scripts. - - Args: - task_name: the name of the SeqIO task to be created and registered. - interactive_model: an instance of the InteractiveModel. - examples: a single batch of examples that should be transformed into a - tf.data.Dataset. The examples can either take the form of a string (ex: a - single input for inference), or a dictionary mapping "input"/"target" to a - string containing that element. - preprocessors: an optional list of functions that receive a tf.data.Dataset - and return a tf.data.Dataset. These will be executed sequentially and the - final dataset must include features matching `self._features`. - metric_fns: list(callable), an optional list of metric functions with a - signature that matches one of three possible forms: - (targets, scores) - - Note that `scores` refers to the score the model assigned the target - sequence, given the input. - (targets, predictions) - (targets, - predictions, aux_values) - Note that `aux_values` refers to a dictionary - of auxiliary values that the model assigned to each sequence. - add_to_registry: if True, will register the new task. - - Returns: - A SeqIO task. - """ - - def dataset_fn(split, shuffle_files): - del split, shuffle_files - return get_dataset_from_natural_text_examples( - examples, - preprocessors=[], - task_feature_lengths=interactive_model._task_feature_lengths, # pylint: disable=protected-access - features={}, - ) - - data_source = seqio.FunctionDataSource( - dataset_fn=dataset_fn, splits=["train", "validation"] - ) - - if add_to_registry: - seqio.TaskRegistry.add( - task_name, - data_source, - preprocessors=preprocessors, - output_features=interactive_model._features, # pylint: disable=protected-access - metric_fns=metric_fns, - ) - - return seqio.get_mixture_or_task(task_name) - - -# pylint: disable=protected-access -def get_gin_config_from_interactive_model( - interactive_model: InteractiveModel, - script_type: T5XScriptType, - task_name: str, - partitioner_config_str: str, - model_config_str: str, - train_steps: int = 1, - imports_str: str = "", -): - """Converts an InteractiveModel instance into a Gin config string. - - This function will be used to graduate people to the T5X/SeqIO-based - train/infer/eval scripts. - - Args: - interactive_model: an instance of the InteractiveModel. - script_type: which T5X script the Gin config should function with. - task_name: the name of the SeqIO task to be used. - partitioner_config_str: a string that defines the Partitioner object in the - Gin config. - model_config_str: a string that defines the Model object in the Gin config. - train_steps: the number of steps to train for, only used if FINETUNING or - PRETRAINING is selected as the script type. - imports_str: if the `model_config_str` or `partitioner_config_str` relies on - some other files to be imported, these import statements can be included - in the final Gin file by adding them to this string. - - Returns: - A string that contains the full Gin file to be used for train/infer/eval.py. - """ - restore_config_str = "" - if interactive_model._restore_checkpoint_cfg: - restore_config_str = f"""CHECKPOINT_PATH = '{interactive_model._restore_checkpoint_cfg.path}' -utils.RestoreCheckpointConfig: - path = %CHECKPOINT_PATH - mode = '{interactive_model._restore_checkpoint_cfg.mode}' - dtype = '{interactive_model._restore_checkpoint_cfg.dtype}'""" - - base_config_str = f""" -{imports_str} - -MODEL_DIR = "{interactive_model._output_dir}" -MIXTURE_OR_TASK_NAME = "{task_name}" -TASK_FEATURE_LENGTHS = {interactive_model._task_feature_lengths} -USE_CACHED_TASKS = False -SHUFFLE_TRAIN_EXAMPLES = False -BATCH_SIZE = {interactive_model._batch_size} - -{model_config_str} -{partitioner_config_str} -{restore_config_str}""" - - if script_type == T5XScriptType.INFERENCE: - if not interactive_model._restore_checkpoint_cfg: - raise ValueError("A checkpoint must be provided to run inference.") - gin_config = f""" -include 't5x/configs/runs/infer.gin' -{base_config_str} - -INFER_OUTPUT_DIR = %MODEL_DIR - -utils.DatasetConfig: - use_cached = %USE_CACHED_TASKS - batch_size = %BATCH_SIZE - shuffle = False - seed = 0 - pack = False -""" - elif ( - script_type == T5XScriptType.FINETUNING - or script_type == T5XScriptType.PRETRAINING - or script_type == T5XScriptType.EVALUATION - ): - gin_config = f""" -from __gin__ import dynamic_registration - -import __main__ as train_script -from t5x import utils - -include 't5x/configs/runs/pretrain.gin' -{base_config_str} -utils.SaveCheckpointConfig: - period = {interactive_model._save_checkpoint_cfg.period} - dtype = '{interactive_model._save_checkpoint_cfg.dtype}' - keep = {interactive_model._save_checkpoint_cfg.keep} - save_dataset = {interactive_model._save_checkpoint_cfg.save_dataset} - -TRAIN_STEPS = {train_steps} -SHUFFLE_TRAIN_EXAMPLES = False -DROPOUT_RATE = 0.0 - -train/utils.DatasetConfig: - pack = False - -train_eval/utils.DatasetConfig: - pack = False -""" - if script_type == T5XScriptType.EVALUATION: - gin_config += """ -train_script.train: - run_eval_before_training = True - eval_period = 0 - total_steps = 0 -""" - return gin_config - - -# pylint: enable=protected-access diff --git a/t5x-main/t5x/losses.py b/t5x-main/t5x/losses.py deleted file mode 100644 index 8915c813feaa5c12570ecddb74d3f93c2328c4b6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/losses.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Loss functions.""" - -import enum -from typing import Mapping, Optional, Tuple, Union - -from flax.training import common_utils -import jax -import jax.numpy as jnp -import numpy as np - - -@jax.custom_vjp -def cross_entropy_with_logits( - logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Computes cross entropy loss with stable custom gradient. - - Computes a stabilized-gradient version of: - -jnp.sum(targets * nn.log_softmax(logits), axis=-1) - - If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 - will be added to the cross entropy loss (z = softmax normalization constant). - The two uses of z_loss are: - 1. To keep the logits from drifting too far from zero, which can cause - unacceptable roundoff errors in bfloat16. - 2. To encourage the logits to be normalized log-probabilities. - - Args: - logits: [batch, length, num_classes] float array. - targets: categorical one-hot targets [batch, length, num_classes] float - array. - z_loss: coefficient for auxilliary z-loss loss term. - - Returns: - tuple with the total loss and the z_loss, both - float arrays with shape [batch, length]. - """ - logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) - log_softmax = logits - logits_sum - loss = -jnp.sum(targets * log_softmax, axis=-1) - # Add auxilliary z-loss term. - log_z = jnp.squeeze(logits_sum, axis=-1) - total_z_loss = z_loss * jax.lax.square(log_z) - loss += total_z_loss - return loss, total_z_loss - - -def _cross_entropy_with_logits_fwd( - logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0 -) -> Tuple[ - Tuple[jnp.ndarray, jnp.ndarray], - Tuple[ - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - ], -]: - """Forward-mode of `cross_entropy_with_logits`.""" - max_logit = logits.max(axis=-1, keepdims=True) - shifted = logits - max_logit - exp_shifted = jnp.exp(shifted) - sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True) - log_softmax = shifted - jnp.log(sum_exp) - loss = -jnp.sum(targets * log_softmax, axis=-1) - # Add auxilliary z-loss term. - log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) - total_z_loss = z_loss * jax.lax.square(log_z) - loss += total_z_loss - return (loss, total_z_loss), ( - logits, - targets, - z_loss, - exp_shifted, - sum_exp, # pytype: disable=bad-return-type # jax-ndarray - log_softmax, - log_z, - ) - - -def _cross_entropy_with_logits_bwd( - res: Tuple[ - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - jnp.ndarray, - ], - g: Tuple[jnp.ndarray, jnp.ndarray], -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """Backward-mode of `cross_entropy_with_logits`.""" - g = g[0] # Ignore z_loss component as that is only used for logging. - logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res - # z-loss term adds the (2 * z_loss * log_z) factor. - deriv = ( - jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - - targets - ) - g_logits = jnp.expand_dims(g, axis=-1) * deriv - g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax - return ( - jnp.asarray(g_logits, logits.dtype), - jnp.asarray(g_targets, targets.dtype), - jnp.array(0.0), - ) # sets z-loss coeff gradient to 0 - - -cross_entropy_with_logits.defvjp( - _cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd -) - - -def compute_weighted_cross_entropy( - logits: jnp.ndarray, - targets: jnp.ndarray, - weights: Optional[jnp.ndarray] = None, - label_smoothing: float = 0.0, - z_loss: float = 0.0, - loss_normalizing_factor: Optional[float] = None, -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """Compute weighted cross entropy and entropy for log probs and targets. - - Args: - logits: [batch, length, num_classes] float array. - targets: categorical targets [batch, length] int array. - weights: None or array of shape [batch, length]. - label_smoothing: label smoothing constant, used to determine the on and off - values. - z_loss: coefficient for auxiliary z-loss loss term. - loss_normalizing_factor: Constant to divide loss by. If not specified, loss - will not be normalized. Intended for backward compatibility with T5-MTF - training. Should not normally be used. - - Returns: - Tuple of scalar loss, z_loss, and weight sum. - """ - if logits.ndim != targets.ndim + 1: - raise ValueError( - 'Incorrect shapes. Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) - ) - vocab_size = logits.shape[-1] - confidence = 1.0 - label_smoothing - if vocab_size == 1: - low_confidence = 1.0 - confidence - else: - low_confidence = (1.0 - confidence) / (vocab_size - 1) - normalizing_constant = -( - confidence * jnp.log(confidence) - + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) - ) - soft_targets = common_utils.onehot( - targets, vocab_size, on_value=confidence, off_value=low_confidence - ) - total_loss, total_z_loss = cross_entropy_with_logits( - logits, soft_targets, z_loss=z_loss - ) - total_loss = total_loss - normalizing_constant - - weight_sum = np.prod(targets.shape) - if weights is not None: - total_loss = total_loss * weights - total_z_loss = total_z_loss * weights - weight_sum = jnp.sum(weights) - - # By default, we do not normalize loss based on anything. - # We don't normalize based on batch size because the optimizers we use are - # pretty much scale invariant, so this simplifies things. - # We don't normalize based on number of non-padding tokens in order to treat - # each token as equally important regardless of sequence length. - if loss_normalizing_factor is not None: - total_loss /= loss_normalizing_factor - total_z_loss /= loss_normalizing_factor - return jnp.sum(total_loss), jnp.sum(total_z_loss), weight_sum - - -@enum.unique -class SpecialLossNormalizingFactor(enum.Enum): - """Specially calculated loss_normalizing_factors, that are not a constant. - - Attributes: - NUM_REAL_TARGET_TOKENS: Whether to divide the loss by the number of real - (non-padding) tokens in the current target batch. If - 'decoder_loss_weights' are specified, it will be the sum of the weights. - Otherwise it will be the number of non-zero 'decoder_target_tokens'. - NUM_TOTAL_TARGET_TOKENS: Whether to divide the loss by the total number of - target tokens, i.e., batch_size * target_seq_length (including padding). - AVERAGE_PER_SEQUENCE: This will first compute the per-sequence loss - (averaged over the number of real target tokens in the sequence), and then - compute the average of that over the sequences. This can be preferable to - NUM_REAL_TARGET_TOKENS for finetuning, because it will weigh all examples - equally, regardless of sequence length (which can be especially important - for multi-task finetuning). - """ - - NUM_REAL_TARGET_TOKENS = 1 - NUM_TOTAL_TARGET_TOKENS = 2 - AVERAGE_PER_SEQUENCE = 3 - - -def convert_special_loss_normalizing_factor_to_enum( - x: str, -) -> SpecialLossNormalizingFactor: - """Converts stringified version of LNF to an enum. - - This is useful because gin dynamic registration does not (currently) - have support for enum. - - Args: - x: stringified version of SpecialLossNormalizingFactor enum. - - Returns: - SpecialLossNormalizingFactor enum instance. - """ - x = x.upper() - if x == 'NUM_REAL_TARGET_TOKENS': - return SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS - if x == 'NUM_TOTAL_TARGET_TOKENS': - return SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS - if x == 'AVERAGE_PER_SEQUENCE': - return SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE - raise ValueError( - 'Could not convert string "%s" to SpecialLossNormalizingFactor' % x - ) - - -@jax.vmap -def _sum_weights_per_segment( - positions: jnp.ndarray, segment_ids: jnp.ndarray, weights: jnp.ndarray -) -> jnp.ndarray: - """Sums weights per packed segment to produce a normalizing vector.""" - - # NB: Assumes padding only occurs at the end of a sequence. - - def _repeat_last_nonnegative(xs, reverse=False): - def fn(prev, x): - y = jnp.where(x == 0, prev, x) - return y, y - - return jax.lax.scan(fn, jnp.zeros_like(xs[0]), xs, reverse=reverse)[1] - - # Compute final positions per sequence. - start_positions = positions == 0 - final_positions = jnp.concatenate([start_positions[1:], jnp.ones(1)]) - # Clear padded positions. - final_positions *= segment_ids != 0 - # Compute cumulative weights, clearing all but final position per sequence. - final_cumulative_weights = final_positions * jnp.cumsum(weights) - # Subtract sequences' final weights from cumulative weights of following ones. - final_total_weights = jnp.concatenate([ - final_cumulative_weights[0:1], - jnp.diff(_repeat_last_nonnegative(final_cumulative_weights)), - ]) - # Copy final sequence weight to all positions in sequence. - normalizer = _repeat_last_nonnegative(final_total_weights, reverse=True) - return normalizer - - -def get_loss_normalizing_factor_and_weights( - loss_normalizing_factor: Optional[ - Union[float, int, str, SpecialLossNormalizingFactor] - ], - batch: Mapping[str, jnp.ndarray], -): - """Get the float loss_normalizing_factor and loss weights. - - If loss_normalizing_factor is float or None, this will simply return the - input loss_normalizing_factor and batch. - - If loss_normalizing_factor is a SpecialLossNormalizingFactor, it will - return a float loss_normalizing_factor and loss weights corresponding to - the special LNF. See SpecialLossNormalizingFactor for more details. - - Args: - loss_normalizing_factor: The input LNF, which may be a float, None, or - SpecialLossNormalizingFactor (or a stringified SLNF). - batch: Input data batch. - - Returns: - Tuple of (output_loss_normalizing_factor, loss_weights). - 'output_loss_normalizing_factor' is a scalar float (Python float - or jnp float). - 'loss_weights' is the per token loss weight JNP array. - """ - - loss_weights = batch.get('decoder_loss_weights', None) - if loss_normalizing_factor is None or not isinstance( - loss_normalizing_factor, (str, SpecialLossNormalizingFactor) - ): - return (loss_normalizing_factor, loss_weights) - - if isinstance(loss_normalizing_factor, str): - loss_normalizing_factor = convert_special_loss_normalizing_factor_to_enum( - loss_normalizing_factor - ) - - # If `loss_weights` are not provided, we assume that the padding id is 0 and - # that non-padding tokens in the decoder all correspond to the positions - # where loss should be taken. If more fine-grained behavior (e.g., taking - # loss on subset of 'decoder_target_tokens') is desired, provide - # `loss_weights` that account for this. - if loss_weights is None: - loss_weights = jnp.asarray(batch['decoder_target_tokens'] > 0, jnp.float32) - - output_normalizing_factor = None - if ( - loss_normalizing_factor - == SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS - ): - output_normalizing_factor = jnp.sum(loss_weights) - elif ( - loss_normalizing_factor - == SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS - ): - output_normalizing_factor = np.prod(batch['decoder_target_tokens'].shape) - elif ( - loss_normalizing_factor - == SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE - ): - if 'decoder_segment_ids' in batch: # is packed - norm_vec = _sum_weights_per_segment( - batch['decoder_positions'], batch['decoder_segment_ids'], loss_weights - ) - else: - norm_vec = jnp.sum(loss_weights, axis=-1, keepdims=True) - # Handle divide-by-zero. - loss_weights = jnp.nan_to_num( - loss_weights / norm_vec, nan=0, posinf=0, neginf=0 - ) - output_normalizing_factor = jnp.sum(loss_weights) - else: - raise ValueError( - 'Unsupported value of loss_normalizing_factor: %s' - % str(loss_normalizing_factor) - ) - - return (output_normalizing_factor, loss_weights) diff --git a/t5x-main/t5x/losses_test.py b/t5x-main/t5x/losses_test.py deleted file mode 100644 index 0ad0823abd20c91e44bb858c842a9251cfe55fb1..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/losses_test.py +++ /dev/null @@ -1,211 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.losses.""" - -from absl.testing import absltest -import jax -import jax.numpy as jnp -import numpy as np -from t5x import losses - - -class LossTest(absltest.TestCase): - - def test_xent(self): - def lossfn(logits, targets, weights): - loss, z_loss, weight_sum = losses.compute_weighted_cross_entropy( - logits, - targets, - weights, - label_smoothing=0.1, - z_loss=0.1, - loss_normalizing_factor=0.1, - ) - return loss, (z_loss, weight_sum) - - batch_size = 2 - length = 4 - vocab_size = 8 - logits = np.random.normal(size=(batch_size, length, vocab_size)).astype( - np.float32 - ) - targets = np.random.randint(0, vocab_size, size=(batch_size, length)) - weights = np.ones_like(targets) - out = jax.jit(jax.value_and_grad(lossfn, has_aux=True))( - logits, targets, weights - ) - (loss, (z_loss, weight_sum)), dlogits = out - # Just a smoke test for now - # TODO(t5x): Expand test - print(jax.device_get(((loss, (z_loss, weight_sum)), dlogits))) - - -class SpecialLossNormalizingFactorTest(absltest.TestCase): - - def test_num_real_target_tokens(self): - batch = { - 'decoder_target_tokens': jnp.asarray( - [[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32 - ) - } - - (output_lnf, output_loss_weights) = ( - losses.get_loss_normalizing_factor_and_weights( - loss_normalizing_factor=losses.SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS, - batch=batch, - ) - ) - - np.testing.assert_allclose(output_lnf, 6.0, rtol=1e-3) - np.testing.assert_allclose( - output_loss_weights, - np.array( - [[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], - dtype=np.float32, - ), - rtol=1e-3, - ) - - def test_num_total_target_tokens(self): - batch = { - 'decoder_target_tokens': jnp.asarray( - [[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32 - ) - } - - (output_lnf, output_loss_weights) = ( - losses.get_loss_normalizing_factor_and_weights( - loss_normalizing_factor=losses.SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS, - batch=batch, - ) - ) - - np.testing.assert_allclose(output_lnf, 10.0, rtol=1e-3) - np.testing.assert_allclose( - output_loss_weights, - np.array( - [[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], - dtype=np.float32, - ), - rtol=1e-3, - ) - - def test_average_per_sequence(self): - batch = { - 'decoder_target_tokens': jnp.asarray( - [[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32 - ) - } - - (output_lnf, output_loss_weights) = ( - losses.get_loss_normalizing_factor_and_weights( - loss_normalizing_factor=losses.SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE, - batch=batch, - ) - ) - - np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3) - np.testing.assert_allclose( - output_loss_weights, - jnp.asarray( - [[0.25, 0.25, 0.25, 0.25, 0.0], [0.5, 0.5, 0.0, 0.0, 0.0]], - jnp.float32, - ), - rtol=1e-3, - ) - - def test_average_per_sequence_with_weights(self): - batch = { - 'decoder_target_tokens': jnp.asarray( - [[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32 - ), - 'decoder_loss_weights': jnp.asarray( - [[0.5, 1.0, 0.25, 2.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], jnp.float32 - ), - } - - (output_lnf, output_loss_weights) = ( - losses.get_loss_normalizing_factor_and_weights( - loss_normalizing_factor=losses.SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE, - batch=batch, - ) - ) - - np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3) - np.testing.assert_allclose( - output_loss_weights, - jnp.asarray( - [ - [0.5 / 3.75, 1.0 / 3.75, 0.25 / 3.75, 2.0 / 3.75, 0.0], - [1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0, 0.0], - ], - jnp.float32, - ), - rtol=1e-3, - ) - - def test_sum_weights_per_segment(self): - weights = jnp.asarray( - [[0.5, 1.0, 0.25, 2.0, 1.5], [1.0, 2.0, 3.0, 4.0, 5.0]], jnp.float32 - ) - positions = jnp.asarray([[0, 1, 2, 0, 0], [0, 0, 1, 0, 0]]) - segment_ids = jnp.asarray([[1, 1, 1, 2, 3], [1, 2, 2, 3, 0]]) - - norm_vec = losses._sum_weights_per_segment(positions, segment_ids, weights) - - np.testing.assert_allclose( - norm_vec, - jnp.asarray( - [[1.75, 1.75, 1.75, 2.0, 1.5], [1.0, 5.0, 5.0, 4.0, 0.0]], - jnp.float32, - ), - rtol=1e-3, - ) - - def test_average_per_sequence_with_weights_with_packing(self): - batch = { - 'decoder_target_tokens': jnp.asarray( - [[1, 2, 3, 4, 5], [5, 6, 7, 8, 0]], jnp.int32 - ), - 'decoder_loss_weights': jnp.asarray( - [[0.5, 1.0, 0.25, 2.0, 1.5], [1.0, 2.0, 3.0, 4.0, 5.0]], jnp.float32 - ), - 'decoder_positions': jnp.asarray([[0, 1, 2, 0, 0], [0, 0, 1, 0, 0]]), - 'decoder_segment_ids': jnp.asarray([[1, 1, 1, 2, 3], [1, 2, 2, 3, 0]]), - } - - (output_lnf, output_loss_weights) = ( - losses.get_loss_normalizing_factor_and_weights( - loss_normalizing_factor=losses.SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE, - batch=batch, - ) - ) - - np.testing.assert_allclose(output_lnf, 6.0, rtol=1e-3) - np.testing.assert_allclose( - output_loss_weights, - jnp.asarray( - [ - [0.5 / 1.75, 1.0 / 1.75, 0.25 / 1.75, 2.0 / 2.0, 1.5 / 1.5], - [1.0 / 1.0, 2.0 / 5.0, 3.0 / 5.0, 4.0 / 4.0, 0.0], - ], - jnp.float32, - ), - rtol=1e-3, - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/main.py b/t5x-main/t5x/main.py deleted file mode 100644 index 17bad3bb12fd4278782f06207a443c7069672577..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/main.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""The main entrance for running any of the T5X supported binaries. - -Currently this includes train/infer/eval/precompile. - -Example Local (CPU) Pretrain Gin usage - -python -m t5x.main \ - --gin_file=t5x/examples/t5/t5_1_1/tiny.gin \ - --gin_file=t5x/configs/runs/pretrain.gin \ - --gin.MODEL_DIR=\"/tmp/t5x_pretrain\" \ - --gin.TRAIN_STEPS=10 \ - --gin.MIXTURE_OR_TASK_NAME=\"c4_v220_span_corruption\" \ - --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \ - --gin.TASK_FEATURE_LENGTHS="{'inputs': 128, 'targets': 30}" \ - --gin.DROPOUT_RATE=0.1 \ - --run_mode=train \ - --logtostderr -""" - -import concurrent.futures # pylint:disable=unused-import -import enum -import importlib -import os -import sys -from typing import Optional, Sequence - -from absl import app -from absl import flags -from absl import logging - -import fiddle as fdl -import gin -import seqio - -from t5x import config_utils -from t5x import gin_utils -from t5x import utils - - -@enum.unique -class RunMode(enum.Enum): - """All the running mode possible in T5X.""" - - TRAIN = 'train' - EVAL = 'eval' - INFER = 'infer' - PRECOMPILE = 'precompile' - EXPORT = 'export' - - -_GIN_FILE = flags.DEFINE_multi_string( - 'gin_file', - default=None, - help=( - 'Path to gin configuration file. Multiple paths may be passed and ' - 'will be imported in the given order, with later configurations ' - 'overriding earlier ones.' - ), -) - -_GIN_BINDINGS = flags.DEFINE_multi_string( - 'gin_bindings', default=[], help='Individual gin bindings.' -) - -_GIN_SEARCH_PATHS = flags.DEFINE_list( - 'gin_search_paths', - default=['.'], - help=( - 'Comma-separated list of gin config path prefixes to be prepended ' - 'to suffixes given via `--gin_file`. If a file appears in. Only the ' - 'first prefix that produces a valid path for each suffix will be ' - 'used.' - ), -) - -_RUN_MODE = flags.DEFINE_enum_class( - 'run_mode', - default=None, - enum_class=RunMode, - help='The mode to run T5X under', -) - -_TFDS_DATA_DIR = flags.DEFINE_string( - 'tfds_data_dir', - None, - 'If set, this directory will be used to store datasets prepared by ' - 'TensorFlow Datasets that are not available in the public TFDS GCS ' - 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' - 'all `Task`s.', -) - -_DRY_RUN = flags.DEFINE_bool( - 'dry_run', - False, - 'If set, does not start the function but stil loads and logs the config.', -) - - -FLAGS = flags.FLAGS - -# Automatically search for gin files relative to the T5X package. -_DEFAULT_GIN_SEARCH_PATHS = [ - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -] - -# Mapping of run_mode to the attribute used in the imported module, e.g. -# {EVAL : 'evaluate'} will load 'evaluate' in eval.py. -_ATTR_BY_RUN_MODE = { - RunMode.TRAIN: 'train', - RunMode.EVAL: 'evaluate', - RunMode.INFER: 'infer', - RunMode.PRECOMPILE: 'precompile', - RunMode.EXPORT: 'save', -} - -# Extra attributes to set in __main__ from the imported module. This is for -# backward compatibility with existing __main__ references in gin files. -_EXTRA_ATTRS_BY_RUN_MODE = {RunMode.INFER: ('create_task_from_tfexample_file',)} - - -main_module = sys.modules[__name__] - - -def main(argv: Sequence[str]): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - - if _RUN_MODE.value is None: - raise ValueError("'run_mode' flag must be specified when using main.py.") - # Dynamic import the modules based on run_mode, e.g. - # If _RUN_MODE.value is 'train', below is equivalent of doing: - # from t5x import train - # train = train.train - - # _RUN_MODE can never be None after this point. - # pytype: disable=attribute-error - lib_name = _RUN_MODE.value.name.lower() - import_attr = _ATTR_BY_RUN_MODE[_RUN_MODE.value] - # pytype: enable=attribute-error - - parent_module = 't5x' - - - module_to_import = f'{parent_module}.{lib_name}' - - logging.info('Dynamically importing : %s', module_to_import) - imported_lib = importlib.import_module(module_to_import) - - entry_func = getattr(imported_lib, import_attr) - setattr(main_module, import_attr, entry_func) - for attr_name in _EXTRA_ATTRS_BY_RUN_MODE.get(_RUN_MODE.value, ()): - setattr(main_module, attr_name, getattr(imported_lib, attr_name)) - - - if _TFDS_DATA_DIR.value is not None: - seqio.set_tfds_data_dir_override(_TFDS_DATA_DIR.value) - - - if config_utils.using_fdl(): - config = config_utils.config_with_fiddle(entry_func) - run_with_fdl = fdl.build(config) - - if _DRY_RUN.value: - return - - run_with_fdl() - else: - # Register function explicitly under __main__ module, to maintain backward - # compatability of existing '__main__' module references. - gin.register(entry_func, '__main__') - if _GIN_SEARCH_PATHS.value != ['.']: - logging.warning( - 'Using absolute paths for the gin files is strongly recommended.' - ) - - # User-provided gin paths take precedence if relative paths conflict. - gin_utils.parse_gin_flags( - _GIN_SEARCH_PATHS.value + _DEFAULT_GIN_SEARCH_PATHS, - _GIN_FILE.value, - _GIN_BINDINGS.value, - ) - - if _DRY_RUN.value: - return - - run_with_gin = gin.get_configurable(entry_func) - - run_with_gin() - - - -if __name__ == '__main__': - config_utils.run(main) diff --git a/t5x-main/t5x/metrics.py b/t5x-main/t5x/metrics.py deleted file mode 100644 index 2ae9e4fece726224ed5e4cc95a0ce6fb913609ed..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/metrics.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5X Metrics. - -Defines Metric objects and collections used by T5X models. These objects use the -CLU metrics library -""" - -import dataclasses -from typing import MutableMapping, Optional, Union - -from clu import metrics as clu_metrics -import flax # Only used for flax.struct.dataclass. -import jax -import jax.numpy as jnp -import numpy as np - -MetricsMap = MutableMapping[str, clu_metrics.Metric] -Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray] - - -def _check_param(value, *, ndim=None, dtype=jnp.float32): - """Raises a `ValueError` if `value` does not match ndim/dtype. - - Args: - value: Value to be tested. - ndim: Expected dimensions. - dtype: Expected dtype. - - Raises: - A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value` - is not an instance of `jnp.ndarray`. - """ - if ndim is not None and value.ndim != ndim: - raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}") - if dtype is not None and value.dtype != dtype: - raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}") - - -@flax.struct.dataclass -class Sum(clu_metrics.Metric): - """Computes the sum of a scalar or a batch of tensors. - - See also documentation of `Metric`. - """ - - total: Scalar - - @classmethod - def from_model_output(cls, values: Scalar, **_) -> clu_metrics.Metric: - """Initializes a Sum Metric from array (or singular) values. - - Args: - values: array of values to sum (or a single value). - - Returns: - A Sum object. - """ - values = jnp.asarray(values) - if values.ndim == 0: - values = values[None] - return cls(total=values.sum()) - - def merge(self, other: "Sum") -> "Sum": - return type(self)(total=self.total + other.total) - - def compute(self) -> jnp.ndarray: - return jnp.array(self.total) - - -@flax.struct.dataclass -class Step(clu_metrics.Metric): - """Abstract class representing a per-step or step-per metric. - - Tracks number of steps. Must be set manually using replace_steps, since the - use of microbatches may otherwise cause the computation to be incorrect. - - See also documentation of `Metric`. - """ - steps: Optional[int] = 1 - - def replace_steps(self, steps: int) -> "Step": - return self.replace(steps=steps) - - def compute(self) -> jnp.ndarray: - if self.steps is None: - raise ValueError( - "`steps` must be set by calling `replace_steps` before computing metric." - ) - return jnp.array(self.steps) - - -@flax.struct.dataclass -class AveragePerStep(Step): - """Represents per-step average (total divided by number of steps). - - See also documentation of `Step`. - """ - total: Optional[Scalar] = None - - @classmethod - def from_model_output(cls, - values: Scalar, - steps: Optional[int] = 1, - **_) -> clu_metrics.Metric: - """Initializes an AveragePerStep Metric from array (or singular) values. - - Args: - values: array of values to sum (or a single value). - steps: number of steps, defaults to 1. - - Returns: - AveragePerStep object. - """ - values = jnp.asarray(values) - if values.ndim == 0: - values = values[None] - return cls(total=values.sum(), steps=steps) - - def merge(self, other: "AveragePerStep") -> "AveragePerStep": - assert type(self) is type(other) - return type(self)( - total=self.total + other.total, steps=self.steps + other.steps) - - def compute(self) -> jnp.ndarray: - steps = super().compute() - if self.total is None: - raise ValueError("`AveragePerStep` `total` cannot be None.") - return self.total / steps - - -@flax.struct.dataclass -class Time(clu_metrics.Metric): - """Computes the sum of a float-valued metric over a period of time. - - Duration (the denominator) must be set manually. This is because JAX does not - properly support time functions inside compiled functions. Calling time.time() - inside a compiled function results in the stored time being the compilation - time, not the run time. - - See also documentation of `Metric`. - """ - duration: Optional[Scalar] = None - - def merge(self, other: "Time") -> "Time": - return self - - def compute(self) -> jnp.ndarray: - if self.duration is None: - raise ValueError( - "`Time` `duration` must be set by calling `replace_duration` before computing." - ) - return jnp.array(self.duration) - - def replace_duration(self, duration: Scalar) -> "Time": - """Replaces duration with the given value. - - Should be used outside a compiled function to set the duration of the - metric. - - Args: - duration: metric duration - - Returns: - A new Time object. - """ - return self.replace(duration=duration) - - -@flax.struct.dataclass -class TimeRate(Time): - """Computes the sum of a float-valued metric over a period of time. - - Duration (the denominator) must be set using replace_duration. This is because - JAX does not properly support time functions inside compiled functions. - Calling time.time() inside a compiled function results in the stored time - being the compilation time, not the run time. - - See also documentation of `Time` and `Metric`. - """ - - numerator: Optional[jnp.ndarray] = None - - @classmethod - def from_model_output(cls, numerator: float, **_) -> clu_metrics.Metric: - """Initializes a TimeRate Metric from a float value (the numerator). - - Args: - numerator: a float (numerator of the metric) - - Returns: - A TimeRate object. - """ - return cls(numerator=numerator) # pytype: disable=wrong-arg-types # jax-ndarray - - def merge(self, other: "TimeRate") -> "TimeRate": - assert_msg = "Merging with non-None durations is currently not supported." - assert self.duration is None and other.duration is None, assert_msg - return type(self)(numerator=self.numerator + other.numerator) - - def compute(self) -> jnp.ndarray: - duration = super().compute() - return self.numerator / duration - - -@flax.struct.dataclass -class StepsPerTime(Step, Time): - """Represents a metric computed as number of steps per time. - - See also documentation of `Step`. - """ - - @classmethod - def from_model_output(cls, - steps: Optional[int] = 1, - **_) -> clu_metrics.Metric: - """Initializes an StepsPerTime Metric. - - Args: - steps: number of steps, defaults to 1. - - Returns: - StepsPerTime object. - """ - return cls(steps=steps) - - def merge(self, other: "StepsPerTime") -> "StepsPerTime": - assert type(self) is type(other) - return type(self)(steps=self.steps + other.steps) - - def compute(self) -> jnp.ndarray: - steps = Step.compute(self) - duration = Time.compute(self) - return steps / duration - - -def is_metric_obj(obj): - return isinstance(obj, clu_metrics.Metric) - - -def is_time_metric(obj): - return isinstance(obj, Time) - - -def create_metrics_dict(float_metrics_dict): - """Input: dict{str: float} | Output: dict{str: Metric}.""" - return {k: Sum(v) for k, v in float_metrics_dict.items()} - - -def shape_obj_to_defined_obj(obj: clu_metrics.Metric): - """Converts shapes in Metric to zero arrays. - - obj should be a Metric object subclass where each member variable is a - ShapeDtypeStruct (from jax.eval_shape). A new object of the same class where - each member variable is an array of zeros with the same shape and type as - the corresponding variable defined by ShapeDtypeStruct. - - Args: - obj: a clu.metrics.Metric object where each member variable is a - ShapeDtypeStruct (from jax.eval_shape) - - Returns: - A Metric object with class variables initialized as zero arrays. - """ - - def class_attr_shape(a): - attr = getattr(obj, a.name) - if isinstance(attr, clu_metrics.Metric): - return shape_obj_to_defined_obj(attr) - else: - if hasattr(attr, "shape"): - return jnp.zeros(shape=attr.shape, dtype=attr.dtype) - else: - return attr - - return obj.__class__( - **{a.name: class_attr_shape(a) for a in dataclasses.fields(obj)}) # pytype: disable=wrong-arg-types # re-none - - -def set_time_metrics_duration(metrics, duration): - """Sets duration for TimeRate objects in metrics pytree.""" - - def fn(o): - if isinstance(o, Time): - return o.replace_duration(duration) - else: - return o - - return jax.tree.map(fn, metrics, is_leaf=lambda obj: isinstance(obj, Time)) - - -def set_step_metrics_num_steps(metrics, num_steps): - """Sets steps for Step objects in metrics pytree.""" - - def fn(o): - if isinstance(o, Step): - return o.replace_steps(num_steps) - else: - return o - - return jax.tree.map(fn, metrics, is_leaf=is_metric_obj) diff --git a/t5x-main/t5x/metrics_test.py b/t5x-main/t5x/metrics_test.py deleted file mode 100644 index 34879a4ead1929dc4a6615b699652f864ac9c8e6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/metrics_test.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for clu.metrics.""" - -from absl.testing import absltest -from absl.testing import parameterized -import jax.numpy as jnp -import numpy as np -from t5x import metrics - - -class MetricsTest(parameterized.TestCase): - - @parameterized.named_parameters( - ("0d_values", 2.0, 2.0), - ("1d_values", [1, 2, 3], 6.0), - ("2d_values", [[1, 2], [2, 3], [3, 4]], 15.0), - ( - "3d_values", - [[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]], - 27.0, - ), - ) - def test_sum(self, values, expected_result): - self.assertAlmostEqual( - metrics.Sum.from_model_output(values).compute(), expected_result - ) - - def test_time_rate(self): - value = np.array([3.0]) - duration = 2.0 - metric = metrics.TimeRate.from_model_output(value).replace_duration( - duration - ) - self.assertAlmostEqual(metric.compute(), value / duration) - - def test_time_rate_unset_duration(self): - value = jnp.array([3.0]) - metric = metrics.TimeRate.from_model_output(value) - with self.assertRaises(ValueError): - metric.compute() - - def test_time(self): - duration = 2.0 - metric = metrics.Time().replace_duration(duration) - self.assertAlmostEqual(metric.compute(), duration) - - def test_time_unset_duration(self): - metric = metrics.Time() - with self.assertRaises(ValueError): - metric.compute() - - @parameterized.named_parameters( - ("0d_values", 2.0, 2.0), - ("1d_values", [1, 2, 3], 6.0), - ) - def test_average_per_step(self, values, expected_result): - a = metrics.AveragePerStep.from_model_output(values) - m = metrics.set_step_metrics_num_steps({"a": a}, 1) - self.assertAlmostEqual(m["a"].compute(), expected_result) - - steps = 5 - b = metrics.AveragePerStep.from_model_output(values, steps=steps) - m = metrics.set_step_metrics_num_steps({"b": b}, steps) - self.assertAlmostEqual(m["b"].compute(), expected_result / steps) - - def test_steps_per_time(self): - steps = 8.0 - duration = 2.0 - metric = metrics.StepsPerTime.from_model_output( - steps=steps - ).replace_duration(duration) - metrics_dict = metrics.set_step_metrics_num_steps({"metric": metric}, steps) - self.assertAlmostEqual(metrics_dict["metric"].compute(), steps / duration) - - -if __name__ == "__main__": - absltest.main() diff --git a/t5x-main/t5x/models.py b/t5x-main/t5x/models.py deleted file mode 100644 index 3b9fb052fa06ad078adbc8ca9130a97439c75fcd..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/models.py +++ /dev/null @@ -1,1477 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5X Models. - -This module uses layers.py to build a higher-level model structure and define -methods for the loss computation as well as a train, prediction, and evaluation -steps. -""" - -import abc -import dataclasses -import functools -import inspect -from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union - -from absl import logging -import clu.metrics as clu_metrics -from flax import core as flax_core -from flax import linen as nn -from flax.core import scope as flax_scope -from flax.linen import partitioning as flax_partitioning -from flax.training import common_utils -import jax -import jax.numpy as jnp -import numpy as np -import seqio -from t5x import decoding -from t5x import losses -from t5x import metrics as metrics_lib -from t5x import optimizers -import tensorflow as tf -import typing_extensions - -# Remove _ShardedDeviceArray when users of t5x have their types updated -_ShardedDeviceArray = Any -Array = Union[np.ndarray, jnp.ndarray, _ShardedDeviceArray, tf.Tensor] -MetricsMap = metrics_lib.MetricsMap -PyTree = Any -PyTreeDef = jax.tree_util.PyTreeDef - - -class TokensIdsToLogitsCallable(typing_extensions.Protocol): - """Token ids to logits mapping call signature.""" - - def __call__( - self, decoding_state: decoding.DecodingState - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Performs forward pass to convert token ids to logits. - - Args: - decoding_state: Current decoding state, including current token ids and - cache. - - Returns: - a tuple of logits with a shape [batch_size, vocab_size] and an updated - cache. - """ - ... - - -class DecodeFnCallable(typing_extensions.Protocol): - """Decoding function call signature.""" - - def __call__( - self, - *, - inputs: jnp.ndarray, - cache: Mapping[str, jnp.ndarray], - tokens_to_logits: TokensIdsToLogitsCallable, - eos_id: int, - num_decodes: int, - decode_rng: Optional[jax.Array], - cache_offset: int, - **kwargs, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Decoding function interface. - - Args: - inputs: [batch_size, max_decode_len] int32 sequence of tokens, with non-0 - prefix tokens to be used as a forced prompt. - cache: flax attention cache. - tokens_to_logits: fast autoregressive decoder function taking single token - slices and cache and returning next-token logits and updated cache. - eos_id: end-of-sentence token for target vocabulary. - num_decodes: number of decoded sequences to be returned. - decode_rng: an optional JAX PRNG Key for stochastic sampling routines. - cache_offset: axis offset for cache, arising from scanned layers. - **kwargs: an optional kwargs. One common usecase of this is passing - decoding parameters at the callsite. - - Returns: - decodes: Array of sequences: [batch_size, num_decodes, max_decode_len]. - The `num_decodes` dimension is expected to be sorted by the `scores`, - i.e., `decodes[:, -1, :] has the highest scores among `num_decodes` - decoded sequences. - scores: Array of log likelihood scores: [batch_size, num_decodes] - """ - ... - - -class BaseModel(abc.ABC): - """Abstract base class for models. - - Wraps a flax module to provide a basic interface for computing loss, - evaluation metrics, prediction, and scoring. - - Subclasses must implement the abstract methods. Any additional arguments added - to these methods must have defaults or be bound at run time to fit the - interface expected by the standard training, inference, and evaluation - functions. - """ - - FEATURE_CONVERTER_CLS: Callable[..., seqio.FeatureConverter] - - def __init__(self, optimizer_def: optimizers.OptimizerDefType): - # TODO(jbulian): Move the optimizer out of the model and make it a training - # parameter. - self.optimizer_def = optimizer_def - - @abc.abstractmethod - def loss_fn( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jax.Array], - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Computes loss and metrics. - - Args: - params: model parameters. - batch: a batch of inputs. - dropout_rng: rng to use for dropout, or None for deterministic mode. - - Returns: - loss: the loss computed for the given inputs and parameters. - aux: - weight_sum: sum of the per-token weights applied to the loss. - metrics: a mapping of metrics computed for this batch. - """ - pass - - def eval_fn( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Computes loss and metrics during the evaluation. - - Args: - params: model parameters. - batch: a batch of inputs. - - Returns: - loss: the loss computed for the given inputs and parameters. - aux: - weight_sum: sum of the per-token weights applied to the loss. - metrics: a mapping of metrics computed for this batch. - """ - return self.loss_fn( - params=params, - batch=batch, - dropout_rng=None, - ) - - def predict_batch( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: - """Predicts a batch of outputs from the model. - - Args: - params: model parameters. - batch: a batch of inputs. - rng: an optional RNG to use during prediction (e.g., for decoding). - - Returns: - The model predictions. - """ - return self.predict_batch_with_aux(params=params, batch=batch, rng=rng)[0] - - @abc.abstractmethod - def predict_batch_with_aux( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.Array] = None, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Predict a batch from the model with auxiliary outputs. - - Args: - params: model parameters. - batch: a batch of inputs. - rng: an optional RNG key to use during prediction (e.g., for decoding). - - Returns: - predictions: the model predictions - aux: auxiliary data - """ - pass - - @abc.abstractmethod - def score_batch( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - return_intermediates: bool = False, - ) -> jnp.ndarray: - """Computes scores for batch.""" - pass - - @abc.abstractmethod - def get_initial_variables( - self, - rng: jax.Array, - input_shapes: Mapping[str, Array], - input_types: Optional[Mapping[str, jnp.dtype]] = None, - ) -> flax_scope.FrozenVariableDict: - """Returns the initial variables of the model.""" - pass - - -class BaseTransformerModel(BaseModel): - """Abstract base class for Transformer models. - - Subclasses must implement `predict_batch_with_aux`, `score_batch`, - `get_initial_variables` from `BaseModel` as well as `_compute_logits`. - """ - - def __init__( - self, - module: nn.Module, - input_vocabulary: seqio.Vocabulary, - output_vocabulary: seqio.Vocabulary, - optimizer_def: optimizers.OptimizerDefType, - decode_fn: Optional[DecodeFnCallable] = None, - label_smoothing: float = 0.0, - z_loss: float = 0.0, - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] = None, - ): - self.module = module - self._input_vocabulary = input_vocabulary - self._output_vocabulary = output_vocabulary - self._decode_fn = decode_fn - self._label_smoothing = label_smoothing - self._z_loss = z_loss - self._loss_normalizing_factor = loss_normalizing_factor - - super().__init__(optimizer_def=optimizer_def) - - @property - def input_vocabulary(self): - return self._input_vocabulary - - @property - def output_vocabulary(self): - return self._output_vocabulary - - @property - def decode_fn(self): - return self._decode_fn - - @abc.abstractmethod - def _compute_logits( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jax.Array] = None, - ) -> jnp.ndarray: - """Computes logits via a forward pass of the model.""" - pass - - def loss_fn( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jax.Array], - ) -> Tuple[jnp.ndarray, MetricsMap]: - """Loss function used for training with a cross-entropy loss.""" - logits = self._compute_logits(params, batch, dropout_rng) - - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] - (loss_normalizing_factor, weights) = ( - losses.get_loss_normalizing_factor_and_weights( - self._loss_normalizing_factor, batch - ) - ) - - loss, z_loss, _ = losses.compute_weighted_cross_entropy( - logits, - targets=batch['decoder_target_tokens'], - weights=weights, - label_smoothing=self._label_smoothing, - z_loss=self._z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - - # segment ids to compute packing, padding etc. - segment_ids = { - k[: -len('_segment_ids')]: v - for k, v in batch.items() - if k.endswith('_segment_ids') - } - # If these don't exist then we can create only padding mask. - if not segment_ids: - segment_ids = { - k: v != 0 - for k, v in batch.items() - if k in ('encoder_input_tokens', 'decoder_target_tokens') - } - - metrics = self._compute_metrics( - logits=logits, - targets=batch['decoder_target_tokens'], - mask=weights, - loss=loss, - z_loss=z_loss, - segment_ids=segment_ids, - ) - return loss, metrics - - def _compute_metrics( - self, - logits: jnp.ndarray, - targets: jnp.ndarray, - mask: jnp.ndarray, - loss: jnp.ndarray, - z_loss: Optional[jnp.ndarray] = None, - segment_ids: Optional[Mapping[str, jnp.ndarray]] = None, - ) -> MetricsMap: - return compute_base_metrics( - logits=logits, - targets=targets, - mask=mask, - loss=loss, - z_loss=z_loss, - segment_ids=segment_ids, - ) - - -@dataclasses.dataclass(frozen=True) -class DecoderParams: - return_all_decodes: bool = False - num_decodes: int = 1 - - -class EncoderDecoderModel(BaseTransformerModel): - """Wrapper class for the models.Transformer nn.module.""" - - FEATURE_CONVERTER_CLS = seqio.EncDecFeatureConverter - - def __init__( - self, - module: nn.Module, - input_vocabulary: seqio.Vocabulary, - output_vocabulary: seqio.Vocabulary, - optimizer_def: optimizers.OptimizerDefType, - decode_fn: DecodeFnCallable = decoding.beam_search, - feature_converter_cls: Optional[ - Callable[..., seqio.FeatureConverter] - ] = None, - label_smoothing: float = 0.0, - z_loss: float = 0.0, - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] = None, - default_decoder_params: Optional[DecoderParams] = None, - ): - if feature_converter_cls is not None: - self.FEATURE_CONVERTER_CLS = feature_converter_cls # pylint: disable=invalid-name - self._default_decoder_params = default_decoder_params or DecoderParams() - super().__init__( - module=module, - input_vocabulary=input_vocabulary, - output_vocabulary=output_vocabulary, - optimizer_def=optimizer_def, - decode_fn=decode_fn, - label_smoothing=label_smoothing, - z_loss=z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - - def get_initial_variables( - self, - rng: jax.Array, - input_shapes: Mapping[str, Array], - input_types: Optional[Mapping[str, jnp.dtype]] = None, - ) -> flax_scope.FrozenVariableDict: - """Get the initial variables for an encoder-decoder model.""" - input_types = {} if input_types is None else input_types - encoder_shape = input_shapes['encoder_input_tokens'] - encoder_type = input_types.get('encoder_input_tokens', jnp.float32) - decoder_shape = input_shapes['decoder_input_tokens'] - decoder_type = input_types.get('decoder_input_tokens', jnp.float32) - if 'encoder_positions' in input_shapes: - encoder_positions = jnp.ones( - input_shapes['encoder_positions'], - input_types.get('encoder_positions', jnp.int32), - ) - else: - encoder_positions = None - if 'decoder_positions' in input_shapes: - decoder_positions = jnp.ones( - input_shapes['decoder_positions'], - input_types.get('decoder_positions', jnp.int32), - ) - else: - decoder_positions = None - if 'encoder_segment_ids' in input_shapes: - encoder_segment_ids = jnp.ones( - input_shapes['encoder_segment_ids'], - input_types.get('encoder_segment_ids', jnp.int32), - ) - else: - encoder_segment_ids = None - if 'decoder_segment_ids' in input_shapes: - decoder_segment_ids = jnp.ones( - input_shapes['decoder_segment_ids'], - input_types.get('decoder_segment_ids', jnp.int32), - ) - else: - decoder_segment_ids = None - initial_variables = flax_core.freeze( - self.module.init( - rng, - jnp.ones(encoder_shape, encoder_type), - jnp.ones(decoder_shape, decoder_type), - jnp.ones(decoder_shape, decoder_type), - encoder_positions=encoder_positions, - decoder_positions=decoder_positions, - encoder_segment_ids=encoder_segment_ids, - decoder_segment_ids=decoder_segment_ids, - decode=False, - enable_dropout=False, - ) - ) - return initial_variables - - def _compute_logits( # pytype: disable=signature-mismatch # jax-ndarray - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jax.Array] = None, - mutable: flax_scope.CollectionFilter = False, - other_variables: Optional[PyTree] = None, - ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: - """Computes logits via a forward pass of `self.module_cls`.""" - # Dropout is provided only for the training mode. - rngs = {'dropout': dropout_rng} if dropout_rng is not None else None - if other_variables is None: - other_variables = {} - return self.module.apply( - {'params': params, **other_variables}, - batch['encoder_input_tokens'], - batch['decoder_input_tokens'], - batch['decoder_target_tokens'], - encoder_segment_ids=batch.get('encoder_segment_ids', None), - decoder_segment_ids=batch.get('decoder_segment_ids', None), - encoder_positions=batch.get('encoder_positions', None), - decoder_positions=batch.get('decoder_positions', None), - decode=False, - enable_dropout=rngs is not None, - rngs=rngs, - mutable=mutable, - ) - - def _compute_logits_from_slice( - self, - decoding_state: decoding.DecodingState, - params: PyTree, - encoded_inputs: jnp.ndarray, - raw_inputs: jnp.ndarray, - max_decode_length: int, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Token slice to logits from decoder model.""" - flat_ids = decoding_state.cur_token - flat_cache = decoding_state.cache - - # flat_ids: [batch * beam, seq_len=1] - # cache is expanded inside beam_search to become flat_cache - # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len] - # flat_logits: [batch * beam, seq_len=1, vocab] - flat_logits, new_vars = self.module.apply( - {'params': params, 'cache': flat_cache}, - encoded_inputs, - raw_inputs, # only needed for encoder padding mask - flat_ids, - flat_ids, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - mutable=['cache'], - method=self.module.decode, - ) - # Remove sequence length dimension since it's always 1 during decoding. - flat_logits = jnp.squeeze(flat_logits, axis=1) - new_flat_cache = new_vars['cache'] - return flat_logits, new_flat_cache - - def _compute_kv_cache( - self, - params, - encoded_inputs: jnp.ndarray, - encoder_input_tokens: jnp.ndarray, - decoder_input_tokens: jnp.ndarray, - prefill_decoder_prompt: bool = False, - ) -> Tuple[PyTree, Optional[jnp.ndarray]]: - """Initialize the key/value cache, with optional prompt. - - Args: - params: The parameters of the model. - encoded_inputs: Output of the encoder on the inputs. - encoder_input_tokens: Input tokens for the encoder. Only needed for - padding mask. - decoder_input_tokens: Input tokens for the decoder, possibly containing a - prompt. - prefill_decoder_prompt: Whether to prefill the cache using the decoder - prompt. - - Returns: - cache: The initialzed cache. - initial_index: The index of the next position following prefill or None if - `prefill_decoder_prompt` is False. - """ - _, initial_variables = self.module.apply( - {'params': params}, - encoder_input_tokens=jnp.ones_like(encoder_input_tokens), - decoder_input_tokens=jnp.ones_like(decoder_input_tokens), - decoder_target_tokens=jnp.ones_like(decoder_input_tokens), - mutable=['cache'], - decode=True, - enable_dropout=False, - ) - - cache = initial_variables['cache'] - - if not prefill_decoder_prompt: - return cache, None - - # Prefill the cache based on an (optional) prompt. - # We assume the only 0 tokens are a BOS=0 token at the beginning of the - # input and PAD=0 tokens at the end. - inputs_lengths = jnp.sum(decoder_input_tokens != 0, axis=1) - - _, variables_with_cache = self.module.apply( - {'params': params, 'cache': cache}, - encoded=encoded_inputs, - encoder_input_tokens=encoder_input_tokens, # only for padding mask, - decoder_input_tokens=decoder_input_tokens, - decoder_target_tokens=jnp.ones_like(decoder_input_tokens), # for shape - mutable=['cache'], - enable_dropout=False, - prefill=True, - prefill_lengths=inputs_lengths, - method=self.module.decode, - ) - - cache = variables_with_cache['cache'] - if 'position_embedder' in cache['decoder']: - # TODO(adarob): Instead have `module.decode` accept an index. - cache['decoder']['position_embedder'][ - 'position_embedder_index' - ] = inputs_lengths - - return cache, inputs_lengths - - def predict_batch_with_aux( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.Array] = None, - decoder_params: Optional[MutableMapping[str, Any]] = None, - return_all_decodes: bool = None, - num_decodes: int = None, # pytype:disable=annotation-type-mismatch - prompt_with_targets: bool = False, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Predict with fast decoding beam search on a batch. - - Here we refer to "parameters" for values that can be compiled into the - model dynamically, as opposed to static configuration settings that require - a recompile. For example, the model weights and the decoder brevity-penalty - are parameters and can be modified without requiring a recompile. The number - of layers, the batch size and the decoder beam size are configuration - options that require recompilation if changed. - - This method can be used with a customizable decoding function as long as it - follows the signature of `DecodeFnCallable`. In order to provide a unified - interface for the decoding functions, we use generic names. For example, a - beam size is a concept unique to beam search. Conceptually, it corresponds - to the number of sequences returned by the beam search. Therefore, the - generic argument `num_decodes` corresponds to the beam size if - `self._decode_fn` is a beam search. For temperature sampling, `num_decodes` - corresponds to the number of independent sequences to be sampled. Typically - `num_decodes = 1` is used for temperature sampling. - - If `return_all_decodes = True`, the return tuple contains the predictions - with a shape [batch, num_decodes, max_decode_len] and the scores (i.e., log - probability of the generated sequence) with a shape [batch, num_decodes]. - The beam dimension is sorted in increasing order of log-probability. - - If `return_all_decodes = False`, the return tuple contains the predictions - with a shape [batch, max_decode_len] and the scores with a shape [batch]. - - `decoder_params` can be used to pass dynamic configurations to - `self.decode_fn`. An example usage is to pass different random seed (i.e., - `jax.random.PRNGKey(seed)` with different `seed` value). This can be done by - setting `decoder_params['decode_rng'] = jax.random.PRNGKey(seed)`. - - If `prompt_with_targets = True`, then `decoder_prompt_inputs` is initialized - from the batch's `decoder_input_tokens`. The EOS is stripped to avoid - decoding to stop after the prompt by matching to `output_vocabulary.eos_id`. - - Args: - params: model parameters. - batch: a batch of inputs. - rng: an optional RNG key to use during prediction, which is passed as - 'decode_rng' to the decoding function. - decoder_params: additional (model-independent) parameters for the decoder. - return_all_decodes: whether to return the entire beam or just the top-1. - num_decodes: the number of beams to use in beam search. - prompt_with_targets: Whether the force decode decoder_inputs. - - Returns: - A tuple containing: - the batch of predictions, with the entire beam if requested - an auxiliary dictionary of decoder scores - """ - if return_all_decodes is None: - return_all_decodes = self._default_decoder_params.return_all_decodes - if num_decodes is None: - num_decodes = self._default_decoder_params.num_decodes - - # [batch, input_len] - encoder_input_tokens = batch['encoder_input_tokens'] - decoder_input_tokens = batch['decoder_input_tokens'] - - # `decoder_prompt_inputs` is initialized from the batch's - # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop - # after the prompt by matching to `output_vocabulary.eos_id`. - # These inputs are ignored by the beam search decode fn. - if prompt_with_targets: - decoder_prompt_inputs = decoder_input_tokens - decoder_prompt_inputs = decoder_prompt_inputs * ( - decoder_prompt_inputs != self.output_vocabulary.eos_id - ) - else: - decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens) - - encoded_inputs = self.module.apply( - {'params': params}, - encoder_input_tokens, - enable_dropout=False, - method=self.module.encode, - ) - - # Prepare autoregressive cache. - if 'initial_index' not in inspect.signature(self.decode_fn).parameters: - logging.info( - 'Disabling prompt prefilling due to incompatible decode fn: %s.', - self.decode_fn, - ) - prefill_decoder_prompt = False - elif 'prefill' not in inspect.signature(self.module.decode).parameters: - logging.info( - 'Disabling prompt prefilling due to incompatible `module.decode`.' - ) - prefill_decoder_prompt = False - else: - logging.info('Enabling prompt prefilling.') - prefill_decoder_prompt = True - cache, initial_index = self._compute_kv_cache( - params, - encoded_inputs=encoded_inputs, - encoder_input_tokens=encoder_input_tokens, - decoder_input_tokens=decoder_prompt_inputs, - prefill_decoder_prompt=prefill_decoder_prompt, - ) - - # Prepare transformer fast-decoder call for beam search: for beam search, we - # need to set up our decoder model to handle a batch size equal to - # batch_size * num_decodes, where each batch item's data is expanded - # in-place rather than tiled. - # i.e. if we denote each batch element subtensor as el[n]: - # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] - # [batch * num_decodes, input_len, emb_dim] - tokens_ids_to_logits = functools.partial( - self._compute_logits_from_slice, - params=params, - # [batch * num_decodes, input_len, emb_dim] - encoded_inputs=decoding.flat_batch_beam_expand( - encoded_inputs, num_decodes - ), - # [batch * num_decodes, input_len] - raw_inputs=decoding.flat_batch_beam_expand( - encoder_input_tokens, num_decodes - ), - max_decode_length=decoder_input_tokens.shape[1], - ) - - if decoder_params is None: - decoder_params = {} - if initial_index is not None: - # We only set initial_index when it's non-None since it is not supported - # by all decoders. - decoder_params['initial_index'] = initial_index - - if rng is not None: - if decoder_params.get('decode_rng') is not None: - raise ValueError( - f'Got RNG both from the `rng` argument ({rng}) and' - " `decoder_params['decode_rng']`" - f' ({decoder_params["decode_rng"]}). Please specify one or the' - ' other.' - ) - decoder_params['decode_rng'] = rng - - # TODO(hwchung): rename the returned value names to more generic ones. - # Using the above-defined single-step decoder function, run a - # beam search over possible sequences given input encoding. - # decodes: [batch, num_decodes, max_decode_len + 1] - # scores: [batch, num_decodes] - scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers - - if 'eos_id' not in decoder_params: - decoder_params['eos_id'] = self.output_vocabulary.eos_id or 1 - decodes, scores = self._decode_fn( - inputs=decoder_prompt_inputs, - cache=cache, - tokens_to_logits=tokens_ids_to_logits, - num_decodes=num_decodes, - cache_offset=1 if scanned else 0, - **decoder_params, - ) - - # Beam search returns [n_batch, n_beam, n_length] with beam dimension sorted - # in increasing order of log-probability. - # Return the highest scoring beam sequence. - if return_all_decodes: - return decodes, {'scores': scores} - else: - return decodes[:, -1, :], {'scores': scores[:, -1]} - - def score_batch( # pytype: disable=signature-mismatch # jax-ndarray - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - return_intermediates: bool = False, - ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]: - """Compute log likelihood score on a batch.""" - weights = batch['decoder_loss_weights'] - target_tokens = batch['decoder_target_tokens'] - - if return_intermediates: - logits, modified_variables = self._compute_logits( - params=params, batch=batch, mutable=['intermediates'] - ) - - # Inside self.module, we called nn.Module.sow to track various - # intermediate values. We extract them here. - intermediates = flax_core.unfreeze( - modified_variables.get('intermediates', {}) - ) - - # Track per-token labels and loss weights as well. These are not - # intermediate values of logit computation, so we manually add them here. - intermediates.setdefault('decoder', {}) - intermediates['decoder']['target_tokens'] = (target_tokens,) - intermediates['decoder']['loss_weights'] = (weights,) - # Note that the values are singleton tuples. This is because values inside - # `intermediates` should be tuples tracking all instantiations of a value. - # These values each have just one instantiation, hence singletons. - else: - logits = self._compute_logits(params, batch) # type: jnp.ndarray # pytype: disable=annotation-type-mismatch # jax-ndarray - - # Purposefully don't use config.z_loss because that term is for training - # stability and shouldn't affect our reported scores. - token_scores = ( - -losses.cross_entropy_with_logits( - logits, - common_utils.onehot( - target_tokens, logits.shape[-1], on_value=1, off_value=0 - ), - z_loss=0.0, - )[0] - * weights - ) - if return_intermediates: - intermediates['decoder']['token_scores'] = (token_scores,) - - sequence_scores = token_scores.sum(-1) - - if return_intermediates: - return sequence_scores, intermediates - - return sequence_scores - - -class DecoderOnlyModel(BaseTransformerModel): - """Model class for the decoder-only modules. - - It accepts inputs made out of only 'targets' or both 'inputs' - and 'targets'. If both 'inputs' and 'targets' are present, the loss will - be computed only on 'targets'. - - By default the self-attention is fully causal and a given position only - attends to the time steps before and itself. If - `inputs_bidirectional_attention = True`, the attention in the "inputs" region - is bidirectional. This architecture was referred to as "Prefix LM" in Raffel - et al. 2019 (https://arxiv.org/abs/1910.10683). - """ - - FEATURE_CONVERTER_CLS = seqio.DecoderFeatureConverter - - def __init__( - self, - module: nn.Module, - vocabulary: seqio.Vocabulary, - optimizer_def: optimizers.OptimizerDefType, - decode_fn: DecodeFnCallable = decoding.temperature_sample, - inputs_bidirectional_attention: bool = False, - feature_converter_cls: Optional[ - Callable[..., seqio.FeatureConverter] - ] = None, - label_smoothing: float = 0.0, - z_loss: float = 0.0, - loss_normalizing_factor: Optional[ - Union[float, int, str, losses.SpecialLossNormalizingFactor] - ] = None, - ): - if feature_converter_cls is not None: - self.FEATURE_CONVERTER_CLS = feature_converter_cls # pylint: disable=invalid-name - self._inputs_bidirectional_attention = inputs_bidirectional_attention - super().__init__( - module, - input_vocabulary=vocabulary, - output_vocabulary=vocabulary, - optimizer_def=optimizer_def, - decode_fn=decode_fn, - label_smoothing=label_smoothing, - z_loss=z_loss, - loss_normalizing_factor=loss_normalizing_factor, - ) - - def get_initial_variables( - self, - rng: jax.Array, - input_shapes: Mapping[str, Array], - input_types: Optional[Mapping[str, jnp.dtype]] = None, - ) -> flax_scope.FrozenVariableDict: - """Get the initial variables.""" - input_types = {} if input_types is None else input_types - decoder_shape = input_shapes['decoder_input_tokens'] - decoder_type = input_types.get('decoder_input_tokens', jnp.float32) - initial_variables = self.module.init( - rng, - jnp.ones(decoder_shape, decoder_type), - jnp.ones(decoder_shape, decoder_type), - enable_dropout=False, - ) - return flax_core.freeze(initial_variables) - - def _get_decoder_causal_attention(self, batch): - """Returns decoder causal attention from the batch or None.""" - if self._inputs_bidirectional_attention: - if 'decoder_causal_attention' not in batch: - raise ValueError( - '`inputs_bidirectional_attention` mode requires ' - '"decoder_causal_attention" feature in the batch' - ) - decoder_causal_attention = batch['decoder_causal_attention'] - else: - decoder_causal_attention = None - - return decoder_causal_attention - - def _compute_logits( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jax.Array] = None, - mutable: flax_scope.CollectionFilter = False, - other_variables: Optional[PyTree] = None, - ) -> jnp.ndarray: - """Computes logits via a forward pass of `self.module`.""" - rngs = {'dropout': dropout_rng} if dropout_rng is not None else None - decoder_causal_attention = self._get_decoder_causal_attention(batch) - if other_variables is None: - other_variables = {} - - return self.module.apply( - {'params': params, **other_variables}, - batch['decoder_input_tokens'], - batch['decoder_target_tokens'], - decoder_segment_ids=batch.get('decoder_segment_ids', None), - decoder_positions=batch.get('decoder_positions', None), - decoder_causal_attention=decoder_causal_attention, - rngs=rngs, - decode=False, - enable_dropout=rngs is not None, - mutable=mutable, - ) - - def _compute_logits_from_slice( - self, - decoding_state: decoding.DecodingState, - params: PyTree, - max_decode_length: int, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Token slice to logits from decoder model.""" - flat_ids = decoding_state.cur_token - flat_cache = decoding_state.cache - # flat_ids: [batch, seq_len=1] - # flat_cache['cached_(keys|values)']: - # [batch, num_heads, depth_per_head, max_decode_length] - # flat_cache['cache_index']: [batch] - # flat_logits: [batch, seq_len=1, vocab] - flat_logits, new_vars = self.module.apply( - {'params': params, 'cache': flat_cache}, - flat_ids, - flat_ids, - enable_dropout=False, - decode=True, - max_decode_length=max_decode_length, - mutable=['cache'], - ) - # Remove sequence length dimension since it's always 1 during decoding. - flat_logits = jnp.squeeze(flat_logits, axis=1) - new_flat_cache = new_vars['cache'] - return flat_logits, new_flat_cache - - def score_batch( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - return_intermediates: bool = False, - ) -> jnp.ndarray: - """Compute log likelihood score on a batch.""" - - decoder_target_tokens = batch['decoder_target_tokens'] - weights = batch['decoder_loss_weights'] - - if return_intermediates: - logits, modified_variables = self._compute_logits( - params=params, - batch=batch, - dropout_rng=None, - mutable=['intermediates'], - ) - - # Inside self.module, we called nn.Module.sow to track various - # intermediate values. We extract them here. - intermediates = flax_core.unfreeze( - modified_variables.get('intermediates', {}) - ) - - # Track per-token labels and loss weights as well. These are not - # intermediate values of logit computation, so we manually add them here. - intermediates.setdefault('decoder', {}) - intermediates['decoder']['target_tokens'] = (decoder_target_tokens,) - intermediates['decoder']['loss_weights'] = (weights,) - # Note that the values are singleton tuples. This is because values inside - # `intermediates` should be tuples tracking all instantiations of a value. - # These values each have just one instantiation, hence singletons. - else: - logits = self._compute_logits( - params=params, batch=batch, dropout_rng=None - ) - - token_scores = ( - -losses.cross_entropy_with_logits( - logits, - common_utils.onehot( - decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0 - ), - z_loss=0.0, - )[0] - * weights - ) - if return_intermediates: - intermediates['decoder']['token_scores'] = (token_scores,) - - sequence_scores = token_scores.sum(-1) - - if return_intermediates: - return sequence_scores, intermediates # pytype: disable=bad-return-type # jax-ndarray - - return sequence_scores - - def _compute_kv_cache( - self, - params: PyTree, - inputs: jnp.ndarray, - causal_attention_mask: jnp.ndarray, - ) -> Tuple[PyTree, jnp.ndarray]: - """Compute the key/value cache on the input prompt. - - Args: - params: The parameters of the model. - inputs: Tokens to use for prompting, with 0-padding. - causal_attention_mask: A boolean mask containing 1 at positions that are - treated as inputs. - - Returns: - cache: The prefilled cache. - initial_index: The index of the next position following prefill. - """ - # The lengths of the inputs match the number of non-padding positions, - # excluding the initial BOS. - inputs_lengths = jnp.sum(inputs[:, 1:] != 0, axis=-1) - - _, initial_variables = self.module.apply( - {'params': params}, - jnp.ones_like(inputs), - jnp.ones_like(inputs), - enable_dropout=False, - decode=True, - mutable=['cache'], - ) - cache = initial_variables['cache'] - if 'cache_axes' in initial_variables: - cache_axes = initial_variables['cache_axes'] - - cache = jax.tree_util.tree_map( - flax_partitioning.with_sharding_constraint, - cache, - flax_partitioning.get_axis_names(cache_axes), - ) - - # Prefill our cache with all the inputs. `inputs_lengths` is the index of - # the last input token. The cache will be filled for all the input - # positions, save the last input token. The cache index will point to the - # index of this last input token which is considered during prefilling but - # not cached. This re-computation is required as the logits for this - # position are required for selecting the first output token. - # - # The cache is still `[B, ..., max_decode_len]` but any position less than - # the `inputs_length` will be non-zero, that is - # `cached_key[b, ..., i < inputs_lengths[b]] != 0`. - # - # The cache index is now a vector of size [B] = input_lengths - - # If `self._inputs_bidirectional_attention = False`, we should not pass - # batch['decoder_causal_attention'] to `module.apply` during cache prefill - # and pass None instead. - maybe_causal_attention_mask = self._get_decoder_causal_attention( - {'decoder_causal_attention': causal_attention_mask} - ) - - _, variables_with_cache = self.module.apply( - {'params': params, 'cache': cache}, - decoder_input_tokens=inputs, - # Use the `decoder_causal_attention`, which has 1 for all input - # positions, including the BOS token, as the targets so when the - # decoder attention mask is built, it will correctly cover the whole - # input, Using something like the inputs will cause the first input - # token (the 0 for BOS) will not be included in the mask. This also - # restricts the mask to not include any target positions like it would - # if you used `decoder_target_tokens`. - decoder_target_tokens=causal_attention_mask, - decoder_causal_attention=maybe_causal_attention_mask, - mutable=['cache'], - enable_dropout=False, - prefill=True, - prefill_lengths=inputs_lengths, - ) - return variables_with_cache['cache'], inputs_lengths - - def predict_batch_with_aux( - self, - params: PyTree, - batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.Array] = None, - *, - return_all_decodes: bool = False, - num_decodes: int = 1, - decoder_params: Optional[MutableMapping[str, Any]] = None, - ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: - """Predict with prefix. - - `decoder_params` can be used to pass dynamic configurations to - `self.decode_fn`. An example usage is to pass different random seed (i.e., - `jax.random.PRNGKey(seed)` with different `seed` value). This can be done by - setting `decoder_params['decode_rng'] = jax.random.PRNGKey(seed)`. - - Although this method is short, there are a few subtle points that. We use a - running example to make these points clear. - - ``` - Example - inputs = [9, 4, 6, 1] - targets = [3, 9, 1] - - seqio.DecoderFeatureConverter will generate these set of features - - decoder_target_tokens = [9, 4, 6, 1, 3, 9, 1, 0, 0] - decoder_input_tokens = [0, 9, 4, 6, 1, 3, 9, 1, 0] - decoder_causal_attention = [1, 1, 1, 1, 1, 0, 0, 0, 0] - - The output of this function is (a` through `e` are the sampled token ids): - - sampled_sequences = [9, 4, 6, 1, a, b, c, d, e]. - ``` - - Given these set of features, we make a few important observation. - - 1) When a decoder-only model is used for a supervised learning with "inputs" - and "targets", one way to handle this is to concatenate the "inputs" and - "targets". For training, we use teacher forcing for the entire - concatenated sequence. For inference, on the other hand, we don't have - the targets. This requires that we use teacher forcing on the "inputs" - portion while using the generated token as the input token for the next - decoding step. For evaluation, we do have "targets" but we only want to - use them for computing metrics, i.e., by comparing to the sequence - generated by the model. - - This function is currently used for evaluation mode, but by ignoring - "targets", it can be extended for the inference mode. - - 2) During evaluation mode, the targets portion is zeroed out and they are - filled with the sampled token ids. The inputs portion is kept intact. - - 3) Note that `decoder_causal_attention` has an additional 1 after the final - "inputs" token. This is because the position where the last "inputs" - token (in this case 1) is input and the output is the first "target" - token (in this case 3) can be included in the non-causal attention - region. - - This results in an alignment between `decoder_input_tokens` and - `decoder_causal_attention` because the former is shifted to the right by - one position. So we use `decoder_causal_attention` as a binary mask to - zero out the target tokens in `decoder_input_tokens`. - - Note: - In order to use a custom self._decode_fn with this model it must support: - - 1) Decoding from a partially decoded state by accepting a vector of - `initial_indices` that specify where in the input to start decoding - from. - 2) Using a vector as the loop counter to support different examples being - a different number of steps into their decoding loop. - 3) Be able to handle one batch element reaching `max_decode_length` - before the others without it causing the model to prematurely stop - decoding. - - Args: - params: model parameters. - batch: batch element with the model features specified in - seqio.DecoderFeatureConverter. - rng: an optional RNG key to use during prediction, which is passed as - 'decode_rng' to the decoding function. - return_all_decodes: if True, will return all batch_size * num_decodes - samples from the model as an array of shape [batch_size, num_decodes, - sequence_length]. In this case the `num_decodes` dimension is sorted in - increasing order of log-probability. Otherwise returns only the most - likely samples as an array of shape [batch_size, sequence_length]. - num_decodes: number of decoded sequences to be returned. - decoder_params: additional (model-independent) parameters for the decoder. - - Returns: - sampled_sequences: an array of shape [batch, max_decode_length]. - """ - if 'decoder_causal_attention' not in batch: - raise ValueError( - 'Batch does not have the right format for text generation: probably ' - 'because `task_feature_lengths` passed to the feature converter does ' - 'not have both `inputs` and `targets`.' - ) - - # since decoder_input_tokens is shifted to the right and - # `decoder_causal_attention` has one more 1 than the number of inputs - # tokens, this masks out targets portion of the decoder_input_tokens. - inputs = batch['decoder_input_tokens'] * batch['decoder_causal_attention'] - - prefilled_cache, initial_index = self._compute_kv_cache( - params, inputs, batch['decoder_causal_attention'] - ) - - target_shape = batch['decoder_input_tokens'].shape - max_decode_length = target_shape[1] - - tokens_ids_to_logits = functools.partial( - self._compute_logits_from_slice, - params=params, - max_decode_length=max_decode_length, - ) - - if decoder_params is None: - decoder_params = {} - if rng is not None: - if decoder_params.get('decode_rng') is not None: - raise ValueError( - f'Got RNG both from the `rng` argument ({rng}) and' - " `decoder_params['decode_rng']`" - f' ({decoder_params["decode_rng"]}). Please specify one or the' - ' other.' - ) - decoder_params['decode_rng'] = rng - - # Using the above-defined single-step decoder function, run temperature - # sampling with the prefix. - # [batch, max_decode_length] - scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers - - if 'eos_id' not in decoder_params: - decoder_params['eos_id'] = self.output_vocabulary.eos_id or 1 - decoded_sequences, scores = self._decode_fn( - inputs=inputs, - cache=prefilled_cache, - tokens_to_logits=tokens_ids_to_logits, - num_decodes=num_decodes, - initial_index=initial_index, - cache_offset=1 if scanned else 0, - **decoder_params, - ) - - if not return_all_decodes: - # Search returns [n_batch, n_beam/decodes, n_length] with the beam/decode - # dimension sorted in increasing order of log-probability. - # `scores` is [batch, beam/decode_size] - # We take the highest scoring sequence (-1) and its score - decoded_sequences = decoded_sequences[:, -1, :] - # Beam search returns [] - aux = {'scores': scores[:, -1]} - else: - # We return all samples and scores, rather than just the top ones. - aux = {'scores': scores} - - return remove_prefix(decoded_sequences, initial_index), aux - - -@jax.vmap -def remove_prefix( - sequence: jnp.ndarray, prefix_length: jnp.ndarray -) -> jnp.ndarray: - """Remove the prefix portion and shift to the left by the prefix length. - - The example below uses non-decorated function definition, i.e., arrays do not - have batch dimension. `jax.vmap` internally inserts the batch dimension at - axis=0. The shape annotations do not include the batch dimension either. - - Example: - ```python - sequence = [1, 2, 3, 4, 5, 6, 7, 0] - prefix_length = 2 - remove_prefix(sequence, prefix_length) = [3, 4, 5, 6, 7, 0, 0, 0] - ``` - - Note that this function assumes that the padding token has an id of 0. - - Args: - sequence: [length] array. - prefix_length: scalar, i.e., rank 0 array. - - Returns: - [length] array with the prefix removed and the suffix shifted. - """ - length = sequence.shape[-1] - # A binary mask with 1 at inputs. - inputs_mask = jnp.arange(length) < prefix_length - # A binary mask with 1 at the targets and padding positions. - targets_and_padding_mask = jnp.logical_not(inputs_mask).astype(sequence.dtype) - # Since padding id = 0, the padding mask is zeroed out. - targets = sequence * targets_and_padding_mask - # Shift to the left by prefix length. Wrapped elements are already zeroed. - return jnp.roll(targets, -prefix_length, axis=-1) - - -# TODO(cpgaffney) Remove this method when dependencies no longer use - rely on -# WeightedAccuracy Metric instead. -def compute_weighted_accuracy( - logits: jnp.ndarray, - targets: jnp.ndarray, - weights: Optional[jnp.ndarray] = None, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Compute weighted accuracy for log probs and targets. - - Args: - logits: [batch, length, num_classes] float array. - targets: categorical targets [batch, length] int array of categories. - weights: None or array of shape [batch, length] - - Returns: - Scalar accuracy. - """ - if logits.ndim != targets.ndim + 1: - raise ValueError( - 'Incorrect shapes. Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) - ) - accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets) - if weights is not None: - accuracy = accuracy * weights - - return jnp.sum(accuracy) # pytype: disable=bad-return-type # jnp-type - - -# TODO(cpgaffney) remove when users rely on compute_base_metrics -def compute_metrics( - logits: jnp.ndarray, - targets: jnp.ndarray, - weights: jnp.ndarray, - loss: jnp.ndarray, - weight_sum: jnp.ndarray, - additional_metrics: MetricsMap, -) -> MetricsMap: - """Compute summary metrics.""" - accuracy = compute_weighted_accuracy(logits, targets, weights) - metrics = { - 'loss': loss, - 'accuracy': accuracy, - 'weight_sum': weight_sum, - 'num_examples': targets.shape[0], - 'num_tokens': targets.size, - } - metrics = metrics_lib.create_metrics_dict(metrics) - metrics.update(additional_metrics) - return metrics - - -def count_packed_examples(segment_ids: jnp.ndarray) -> int: - """Return the number of packed examples. - - After packing, each row of segment_ids contains the ids of packed examples. - For some model inputs, some features could have some examples but not others. - For example, two tasks in a multimodal setup could be: (1). text -> text, and - (2). image -> text. Examples from (1) will be missing image input feature and - examples from (2) will be missing text input feature. - - To count the packed examples, we count the unique ids in segment_ids excluding - 0s (because of padding). It can be implemented by counting the number of - non-zero values in the first discrete difference along axis=1, plus the number - of rows in segment_ids, and minus the number of padded examples. - - Example: - [[1, 1, 3, 3, 0, 0], - [2, 2, 2, 2, 2, 2], - [2, 7, 7, 7, 7, 0]] has 5 packed examples. - - Args: - segment_ids: [B, L] array. - - Returns: - Scalar count. - """ - - # If there is padding, it's at the end and the id is always 0. - num_padded_examples = jnp.sum(segment_ids[:, -1] == 0) - # Get the first discrete different along axis=1. - first_diff = jnp.diff(segment_ids, n=1, axis=1) - # count = #(non-0 diff) + #(row) - #(padded ex). - return jnp.sum(first_diff != 0) + segment_ids.shape[0] - num_padded_examples # pytype: disable=bad-return-type # jnp-type - - -def compute_base_metrics( - logits: jnp.ndarray, - targets: jnp.ndarray, - mask: jnp.ndarray, - loss: jnp.ndarray, - z_loss: Optional[jnp.ndarray] = None, - segment_ids: Optional[Mapping[str, jnp.ndarray]] = None, -) -> MetricsMap: - """Compute summary metrics. - - Args: - logits: [batch, length, num_classes] float array. - targets: categorical targets [batch, length] int array of categories. - mask: None or array of shape [batch, length]. Note: must consist of boolean - values (float-valued weights not supported). - loss: loss (float) - z_loss: z_loss (float) - segment_ids: Optional dictionary of feature and value is the segment ids used - for packing, i.e. [batch, length] arrays. - - Returns: - Dict of metrics. - """ - num_examples = jnp.array(targets.shape[0]) - num_tokens = jnp.array(targets.size) - num_devices = jax.device_count() - assert num_devices, 'JAX is reporting no devices, but it should.' - # Note: apply mask again even though mask has already been applied to loss. - # This is needed to divide by mask sum, but should not affect correctness of - # the numerator. - nonpadding_tokens = jnp.sum(mask) if mask is not None else targets.size - metrics = { - 'accuracy': clu_metrics.Accuracy.from_model_output( - logits=logits, labels=targets.astype(jnp.int32), mask=mask - ), - 'loss': metrics_lib.AveragePerStep(total=loss), - 'loss_per_nonpadding_target_token': clu_metrics.Average( - total=loss, count=nonpadding_tokens - ), - 'loss_per_all_target_tokens': clu_metrics.Average( - total=loss, count=num_tokens - ), - 'timing/seqs_per_second': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type - numerator=num_examples - ), - 'timing/steps_per_second': metrics_lib.StepsPerTime.from_model_output(), - 'timing/seconds': metrics_lib.Time(), - 'timing/seqs': metrics_lib.Sum(num_examples), - 'timing/seqs_per_second_per_core': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type - numerator=num_examples / num_devices - ), - 'timing/target_tokens_per_second': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type - numerator=num_tokens - ), - 'timing/target_tokens_per_second_per_core': ( - metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type - numerator=num_tokens / num_devices - ) - ), - 'non_padding_fraction/loss_weights': clu_metrics.Average( - total=nonpadding_tokens, count=num_tokens - ), - } - if z_loss is not None: - metrics.update({ - 'z_loss': metrics_lib.AveragePerStep(total=z_loss), - 'z_loss_per_all_target_tokens': clu_metrics.Average( - total=z_loss, count=num_tokens - ), - 'cross_ent_loss': metrics_lib.AveragePerStep(total=loss - z_loss), - 'cross_ent_loss_per_all_target_tokens': clu_metrics.Average( - total=jnp.sum(loss - z_loss), count=num_tokens - ), - }) - - if segment_ids is not None: - total_tokens = jnp.array(0) - total_non_padding_tokens = jnp.array(0) - for feature, feature_segment_ids in segment_ids.items(): - if feature_segment_ids is None or feature_segment_ids.shape[1] == 0: - continue - # Since this is [B, L] with the segment ids in axis = 1. - num_examples = count_packed_examples(feature_segment_ids) - metrics[f'effective_batch_size/{feature}'] = metrics_lib.AveragePerStep( - total=num_examples - ) - # 0s is padding - feature_non_padding = jnp.sum(feature_segment_ids != 0) - feature_size = jnp.array(feature_segment_ids.size) - total_tokens += feature_size - total_non_padding_tokens += feature_non_padding - metrics[f'non_padding_fraction/{feature}'] = clu_metrics.Average( - total=feature_non_padding, count=feature_size - ) - metrics['non_padding_fraction/overall'] = clu_metrics.Average( - total=total_non_padding_tokens, count=total_tokens - ) - - return metrics - - -def get_input_vocabulary(model: BaseTransformerModel) -> seqio.Vocabulary: - return model.input_vocabulary - - -def get_output_vocabulary(model: BaseTransformerModel) -> seqio.Vocabulary: - return model.output_vocabulary diff --git a/t5x-main/t5x/models_test.py b/t5x-main/t5x/models_test.py deleted file mode 100644 index fb65fc0d8ec42bcb2a16d41a0a7e3e07e3894be9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/models_test.py +++ /dev/null @@ -1,1151 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.models.""" - -import functools -from unittest import mock - -from absl import logging -from absl.testing import absltest -from absl.testing import parameterized -import flax -from flax import traverse_util -import jax -import jax.numpy as jnp -import numpy as np -import t5.data.tasks # pylint:disable=unused-import -from t5x import decoding -from t5x import models -from t5x import partitioning -from t5x import test_utils -from t5x import trainer as trainer_lib -from t5x import utils -import tensorflow as tf - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -PartitionSpec = partitioning.PartitionSpec - - -class ModelsTest(parameterized.TestCase): - - def test_remove_prefix(self): - sequences = np.array([[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]]) - prefix_lengths = np.array([2, 4]) - expected = [[3, 4, 5, 6, 7, 0, 0, 0], [10, 11, 0, 0, 0, 0, 0, 0]] - remove_prefix = jax.jit(models.remove_prefix) - actual = remove_prefix(sequences, prefix_lengths) - np.testing.assert_array_equal(actual, expected) - - def test_remove_prefix_zero_len_prefix(self): - sequences = np.array([[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]]) - prefix_lengths = np.array([0, 0]) - remove_prefix = jax.jit(models.remove_prefix) - actual = remove_prefix(sequences, prefix_lengths) - # The expected output is the original sequences. - np.testing.assert_array_equal(actual, sequences) - - def test_count_packed_examples(self): - segment_ids_1 = np.array([[1, 1, 2, 2, 0, 0], [1, 1, 1, 1, 1, 1]]) - actual_1 = models.count_packed_examples(segment_ids_1) - # `actual_1` is DeviceArray(3, dtype=int32). - np.testing.assert_array_equal(actual_1, [3]) - - segment_ids_2 = np.array( - [[1, 1, 3, 3, 0, 0], [2, 2, 2, 2, 2, 2], [2, 7, 7, 7, 7, 0]] - ) - actual_2 = models.count_packed_examples(segment_ids_2) - # `actual_2` is DeviceArray(5, dtype=int32). - np.testing.assert_array_equal(actual_2, [5]) - - -BATCH_SIZE, ENCODER_LEN, MAX_DECODE_LEN, EMBED_DIM = 2, 3, 4, 5 - - -class EncoderDecoderModelTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='no_types', - shapes={ - 'encoder_input_tokens': [1, 512], - 'decoder_input_tokens': [1, 62], - }, - types=None, - ), - dict( - testcase_name='int32', - shapes={ - 'encoder_input_tokens': [1, 512], - 'decoder_input_tokens': [1, 62], - }, - types={ - 'encoder_input_tokens': jnp.int32, - 'decoder_input_tokens': jnp.int32, - }, - ), - dict( - testcase_name='float32', - shapes={ - 'encoder_input_tokens': [1, 512], - 'decoder_input_tokens': [1, 62], - 'encoder_positions': [1, 512], - 'decoder_positions': [1, 62], - }, - types={ - 'encoder_input_tokens': jnp.int32, - 'decoder_input_tokens': jnp.int32, - 'encoder_positions': jnp.int32, - 'decoder_positions': jnp.int32, - }, - ), - dict( - testcase_name='float32_segment_ids', - shapes={ - 'encoder_input_tokens': [1, 512], - 'decoder_input_tokens': [1, 62], - 'encoder_segment_ids': [1, 512], - 'decoder_segment_ids': [1, 62], - }, - types={ - 'encoder_input_tokens': jnp.int32, - 'decoder_input_tokens': jnp.int32, - 'encoder_segment_ids': jnp.int32, - 'decoder_segment_ids': jnp.int32, - }, - ), - ) - def test_get_initial_variables_shapes_and_types(self, shapes, types): - mock_transformer = mock.Mock() - mock_transformer.init.return_value = {'params': {}} - mock_optimizer_def = mock.Mock() - rng = mock.Mock() - - def mock_init(self): - self.module = mock_transformer - self.optimizer_def = mock_optimizer_def - self._default_decoder_params = models.DecoderParams() - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - model.get_initial_variables(rng, shapes, types) - - if types is None: - encoder_input = jnp.ones( - shapes['encoder_input_tokens'], dtype=jnp.float32 - ) - decoder_input = jnp.ones( - shapes['decoder_input_tokens'], dtype=jnp.float32 - ) - else: - encoder_input = jnp.ones( - shapes['encoder_input_tokens'], dtype=types['encoder_input_tokens'] - ) - decoder_input = jnp.ones( - shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens'] - ) - - # Using `.assert_called_once_with` doesn't work because the simple - # comparison it does for the array arguments fail (truth value of an array - # is ambiguous). - called_with = mock_transformer.init.call_args - self.assertEqual(called_with[0][0], rng) - np.testing.assert_allclose(called_with[0][1], encoder_input) - np.testing.assert_allclose(called_with[0][2], decoder_input) - np.testing.assert_allclose(called_with[0][3], decoder_input) - - if 'encoder_positions' in shapes: - encoder_positions = jnp.ones( - shapes['encoder_positions'], dtype=types['encoder_positions'] - ) - np.testing.assert_allclose( - called_with[1]['encoder_positions'], encoder_positions - ) - else: - self.assertIsNone(called_with[1]['encoder_positions']) - if 'decoder_positions' in shapes: - decoder_positions = jnp.ones( - shapes['decoder_positions'], dtype=types['decoder_positions'] - ) - np.testing.assert_allclose( - called_with[1]['decoder_positions'], decoder_positions - ) - else: - self.assertIsNone(called_with[1]['decoder_positions']) - - if 'encoder_segment_ids' in shapes: - encoder_positions = jnp.ones( - shapes['encoder_segment_ids'], dtype=types['encoder_segment_ids'] - ) - np.testing.assert_allclose( - called_with[1]['encoder_segment_ids'], encoder_positions - ) - else: - self.assertIsNone(called_with[1]['encoder_segment_ids']) - if 'decoder_segment_ids' in shapes: - decoder_segment_ids = jnp.ones( - shapes['decoder_segment_ids'], dtype=types['decoder_segment_ids'] - ) - np.testing.assert_allclose( - called_with[1]['decoder_segment_ids'], decoder_segment_ids - ) - else: - self.assertIsNone(called_with[1]['decoder_segment_ids']) - - self.assertFalse(called_with[1]['decode']) - self.assertFalse(called_with[1]['enable_dropout']) - - @parameterized.named_parameters( - dict( - testcase_name='no_force_decoding', - prompt_with_targets=False, - supports_prefilling=True, - ), - dict( - testcase_name='force_decoding_no_prefill', - prompt_with_targets=True, - supports_prefilling=False, - ), - dict( - testcase_name='force_decoding_prefill_positional', - prompt_with_targets=True, - supports_prefilling=True, - position_embedder=True, - ), - dict( - testcase_name='force_decoding_prefill', - prompt_with_targets=True, - supports_prefilling=True, - position_embedder=False, - ), - ) - def test_prompt_with_targets( - self, prompt_with_targets, supports_prefilling, position_embedder=True - ): - batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5 - batch = { - 'encoder_input_tokens': np.zeros( - (batch_size, encoder_len), dtype=np.int32 - ), - 'decoder_input_tokens': np.array( - [[0, 2, 3, 0], [0, 0, 0, 0]], dtype=np.int32 - ), - } - - # These dummy logits represent the probability distribution where all the - # probability mass is in one item (i.e., degenerate distribution). For - # batch element 0, it is vocabulary index 3. - # We test `_predict_step` to avoid having to define a task and its - # vocabulary. - dummy_logits = jnp.expand_dims( - jnp.array([[-1e7, -1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, -1e7, 0]]), - axis=1, - ) - - mock_decode_fn = ( - mock.create_autospec(lambda initial_index, **kwargs: None) - if supports_prefilling - else mock.Mock() - ) - mock_decode_fn.return_value = ( - np.full([batch_size, max_decode_len, 1], 3, dtype=np.int32), - np.full([batch_size, 1], 1.0, dtype=np.float32), - ) - - def mock_module_apply(*args, method=None, **kwargs): - del args, kwargs - if method is None: - return (dummy_logits, {'cache': {}}) - return method() - - mock_init_cache = {'decoder': {}} - if position_embedder: - mock_init_cache['decoder']['position_embedder'] = { - 'position_embedder_index': [0] - } - - mock_module = mock.Mock( - dtype=jnp.float32, - apply=mock.Mock(side_effect=mock_module_apply), - encode=lambda: jnp.zeros((batch_size, encoder_len, emb_dim)), - decode=lambda prefill=None: (dummy_logits, {'cache': mock_init_cache}), - ) - - def mock_init(self): - self.module = mock_module - self.module.scan_layers = False - self._default_decoder_params = models.DecoderParams() - self._input_vocabulary = mock.Mock(eos_id=1) - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = mock_decode_fn - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - - model.predict_batch_with_aux( - {}, batch, prompt_with_targets=prompt_with_targets - ) - - if prompt_with_targets: - expected_inputs = batch['decoder_input_tokens'] - else: - expected_inputs = np.zeros([batch_size, max_decode_len], dtype=np.int32) - - assert mock_decode_fn.call_count == 1 - # Look at the kwargs call list for inputs, assert_called_with doesn't - # work well with np.array comparison. - np.testing.assert_array_equal( - mock_decode_fn.mock_calls[0][2]['inputs'], expected_inputs - ) - if supports_prefilling: - if prompt_with_targets: - expected_initial_index = [2, 0] - else: - expected_initial_index = [0, 0] - np.testing.assert_array_equal( - mock_decode_fn.mock_calls[0][2]['initial_index'], - expected_initial_index, - ) - # Encode inputs, initialize decoder, cache prompt. - self.assertEqual(mock_module.apply.call_count, 3) - if position_embedder: - np.testing.assert_array_equal( - mock_decode_fn.mock_calls[0][2]['cache']['decoder'][ - 'position_embedder' - ]['position_embedder_index'], - expected_initial_index, - ) - else: - self.assertNotIn('initial_index', mock_decode_fn.mock_calls[0][2]) - # Encode inputs, initialize decoder. - self.assertEqual(mock_module.apply.call_count, 2) - - def test_predict_batch_loop_and_caches_are_equal(self): - vocab_size = 50 - lengths = np.array([[2], [3]]) - batch_size, beam_size, encoder_len, max_decode_len = 2, 2, 3, 7 - batch = { - 'encoder_input_tokens': np.zeros( - (batch_size, encoder_len), dtype=np.int32 - ), - 'decoder_target_tokens': np.zeros( - (batch_size, encoder_len), dtype=np.int32 - ), - 'decoder_input_tokens': np.concatenate( - [ - np.expand_dims( - np.concatenate([ - [0], - np.arange(9, 9 + lengths[0][0], dtype=np.int32), - np.zeros( - (max_decode_len - lengths[0][0] - 1), dtype=np.int32 - ), - ]), - axis=0, - ), # First element - np.expand_dims( - np.concatenate([ - [0], - np.arange(3, 3 + lengths[1][0], dtype=np.int32), - np.zeros( - (max_decode_len - lengths[1][0] - 1), dtype=np.int32 - ), - ]), - axis=0, - ), # Second element - ], - axis=0, - ), - } - - model = test_utils.get_t5_test_model(vocab_size=50) - module = model.module - params = module.init( - jax.random.PRNGKey(0), - jnp.ones((batch_size, encoder_len)), - jnp.ones((batch_size, max_decode_len)), - jnp.ones((batch_size, max_decode_len)), - enable_dropout=False, - )['params'] - - def mock_init(self): - self.module = module - self._default_decoder_params = models.DecoderParams() - # Set the EOS token to be larger then the vocabulary size. This forces the - # model to decode all the way to `max_decode_length`, allowing us to test - # behavior when one element reaches the end before the others. - self._output_vocabulary = mock.Mock(eos_id=vocab_size + 12) - self._decode_fn = decoding.beam_search - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - - with mock.patch.object( - model, '_compute_logits_from_slice', autospec=True - ) as tokens_to_logits_mock: - # Make the side effect of the mock, call the method on the class, with the - # instance partialed in as `self`. This lets us call the actual code, - # while recording the inputs, without an infinite loop you would get - # calling `instance.method` - tokens_to_logits_mock.side_effect = functools.partial( - models.EncoderDecoderModel._compute_logits_from_slice, model - ) - # Disable jit, so that the `lax.while_loop` isn't traced, as the - # collection of tracers in the mock call_args would generally trigger a - # tracer leak error. - with jax.disable_jit(): - _ = model.predict_batch_with_aux( - params, batch, prompt_with_targets=True, num_decodes=2 - ) - - # Collect all the input tokens to our tokens_to_logits function - all_inputs = [] - all_cache_keys = [] # Collect all the cache keys - all_cache_values = [] # Collect all the cache values - # Currently force decoding generates logits at every step. We should have - # `max_decode_length` calls to our tokens -> logits func. - self.assertLen(tokens_to_logits_mock.call_args_list, max_decode_len) - for tokens_call in tokens_to_logits_mock.call_args_list: - decoding_state: decoding.DecodingState = tokens_call[0][0] - # Inputs: [B, 1] - inputs = decoding_state.cur_token - cache = decoding_state.cache - cache = flax.core.unfreeze(cache) - # Cache: [B * Be, 1] * #Layers - cache_keys = [ - v - for k, v in traverse_util.flatten_dict(cache).items() - if k[-1] == 'cached_key' - ] - cache_values = [ - v - for k, v in traverse_util.flatten_dict(cache).items() - if k[-1] == 'cached_value' - ] - all_inputs.append(inputs) - all_cache_keys.append(cache_keys) - all_cache_values.append(cache_values) - # Convert inputs to a single block [B, DL, Be] - all_inputs = np.concatenate(all_inputs, axis=1) - # Convert caches into a single block per layer [B * Be, DL] * L - all_cache_keys = [np.stack(c, axis=1) for c in zip(*all_cache_keys)] - all_cache_values = [np.stack(c, axis=1) for c in zip(*all_cache_values)] - - # Make sure that for each batch, the cache for each beam is identical when - # prompt is being forced. - for b in range(batch_size): - for i, input_token in enumerate(all_inputs[b * beam_size]): - if i < lengths[b]: - self.assertEqual(input_token, batch['decoder_input_tokens'][b][i]) - # For all layers. - for cache_keys in all_cache_keys: - np.testing.assert_array_equal( - cache_keys[b * beam_size][i], cache_keys[b * beam_size + 1][i] - ) - for cache_values in all_cache_values: - np.testing.assert_array_equal( - cache_values[b * beam_size][i], - cache_values[b * beam_size + 1][i], - ) - - def test_score_batch(self): - encoder_input_tokens = jnp.ones((2, 3)) - # For this test, decoder input and target tokens are dummy values. - decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) - decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) - decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]) - logits = jnp.arange(0, 24).reshape((2, 4, 3)) - params = {'foo': jnp.zeros(3)} - - mock_transformer = mock.Mock() - mock_transformer.apply.return_value = logits - mock_transformer.dtype = jnp.float32 - - batch = { - 'encoder_input_tokens': encoder_input_tokens, - 'decoder_input_tokens': decoder_input_tokens, - 'decoder_target_tokens': decoder_target_tokens, - 'decoder_loss_weights': decoder_loss_weights, - } - - def mock_init(self): - self.module = mock_transformer - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - res = model.score_batch(params, batch) - - mock_transformer.apply.assert_called_with( - {'params': params}, - encoder_input_tokens, - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - encoder_positions=None, - decoder_positions=None, - decode=False, - enable_dropout=False, - rngs=None, - mutable=False, - ) - np.testing.assert_allclose(res, [-3.222973, -1.815315], rtol=1e-4) - - def test_score_batch_can_return_intermediates(self): - encoder_input_tokens = jnp.ones((2, 3)) - # For this test, decoder input and target tokens are dummy values. - decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) - decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) - decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]) - logits = jnp.arange(0, 24).reshape((2, 4, 3)) - modified_variables = {'intermediates': {'bar': jnp.ones(5)}} - params = {'foo': jnp.zeros(3)} - - mock_transformer = mock.Mock() - mock_transformer.apply.return_value = (logits, modified_variables) - mock_transformer.dtype = jnp.float32 - - batch = { - 'encoder_input_tokens': encoder_input_tokens, - 'decoder_input_tokens': decoder_input_tokens, - 'decoder_target_tokens': decoder_target_tokens, - 'decoder_loss_weights': decoder_loss_weights, - } - - def mock_init(self): - self.module = mock_transformer - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - scores, intermediates = model.score_batch( - params, batch, return_intermediates=True - ) - - mock_transformer.apply.assert_called_with( - {'params': params}, - encoder_input_tokens, - decoder_input_tokens, - decoder_target_tokens, - encoder_segment_ids=None, - decoder_segment_ids=None, - encoder_positions=None, - decoder_positions=None, - decode=False, - enable_dropout=False, - rngs=None, - mutable=['intermediates'], - ) - np.testing.assert_allclose(scores, [-3.222973, -1.815315], rtol=1e-4) - # Incumbent intermediates are passed out unchanged. - np.testing.assert_allclose(intermediates['bar'], jnp.ones(5)) - # A new collection of decoder intermediates are inserted by score_batch() - np.testing.assert_allclose( - intermediates['decoder']['loss_weights'][0], decoder_loss_weights - ) - np.testing.assert_allclose( - intermediates['decoder']['target_tokens'][0], decoder_target_tokens - ) - - def test_train_transformer_wmt(self): - # Dummy input data - input_shape = (16, 8) - encoder_input_tokens = np.ones(shape=input_shape, dtype=np.float32) - decoder_input_tokens = 5 * np.ones(shape=input_shape, dtype=np.float32) - decoder_target_tokens = 5 * np.ones(input_shape, dtype=np.float32) - # input_data = {'inputs': inputs, 'targets': targets} - input_data = { - 'encoder_input_tokens': encoder_input_tokens, - 'decoder_input_tokens': decoder_input_tokens, - 'decoder_target_tokens': decoder_target_tokens, - } - - partitioner = partitioning.PjitPartitioner(num_partitions=1) - - model = test_utils.get_t5_test_model() - - ds_iter = tf.data.Dataset.from_tensors(input_data).as_numpy_iterator() - input_shapes = {k: input_shape for k in input_data} - - train_state_initializer = utils.TrainStateInitializer( - optimizer_def=model.optimizer_def, - init_fn=model.get_initial_variables, - input_shapes=input_shapes, - partitioner=partitioner, - ) - train_state_axes = train_state_initializer.train_state_axes - - trainer = trainer_lib.Trainer( - model, - train_state=train_state_initializer.from_scratch(jax.random.PRNGKey(0)), - partitioner=partitioner, - eval_names=[], - summary_dir=None, - train_state_axes=train_state_axes, - rng=jax.random.PRNGKey(0), - learning_rate_fn=lambda x: 0.001, - num_microbatches=1, - ) - - trainer.train(ds_iter, 1) - logging.info('optimizer after first step %s', trainer.train_state.params) - - - @parameterized.parameters( - {'decode_fn': decoding.beam_search}, - {'decode_fn': functools.partial(decoding.temperature_sample, topk=4)}, - ) - def test_predict_batch(self, decode_fn): - batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5 - batch = { - 'encoder_input_tokens': np.zeros( - (batch_size, encoder_len), dtype=np.int32 - ), - 'decoder_input_tokens': np.zeros( - (batch_size, max_decode_len), dtype=np.int32 - ), - } - - # These dummy logits represent the probability distribution where all the - # probability mass is in one item (i.e., degenerate distribution). For - # batch element 0, it is vocabulary index 2. - # We test `_predict_step` to avoid having to define a task and its - # vocabulary. - dummy_logits = jnp.expand_dims( - jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1 - ) - - class MockModule: - - def __init__(self): - self.dtype = jnp.float32 - - def apply(self, *args, method=None, **kwargs): - del args, kwargs - if method is None: # use for module.`__call__` - return (dummy_logits, {'cache': {}}) - else: - return method() - - def encode(self): - return jnp.zeros((batch_size, encoder_len, emb_dim)) - - def decode(self): - return (dummy_logits, {'cache': {}}) - - def mock_init(self): - self.module = MockModule() - self.module.scan_layers = False - self._default_decoder_params = models.DecoderParams() - self._input_vocabulary = mock.Mock(eos_id=1) - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = decode_fn - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - - actual = model.predict_batch({}, batch) - # The predicted token for the first batch element is always 2 and it is 3 - # for the second batch element. - expected = [[2] * max_decode_len, [3] * max_decode_len] - np.testing.assert_array_equal(actual, expected) - - def test_predict_batch_rng(self): - batch = { - 'encoder_input_tokens': np.zeros((2, 1), dtype=np.int32), - 'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32), - } - - decode_fn_mock = mock.Mock( - return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2))) - ) - - def mock_init(self): - self.module = mock.Mock( - apply=mock.Mock( - side_effect=lambda *_, **kwargs: ( # pylint:disable=g-long-lambda,g-long-ternary - np.zeros((2, 2)), - {'cache': None}, - ) - if 'mutable' in kwargs - else np.zeros((2, 2)) - ) - ) - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = decode_fn_mock - self._default_decoder_params = models.DecoderParams() - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - - # No RNG - model.predict_batch({}, batch) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertNotIn('decode_rng', decode_fn_kwargs) - - # No RNG (w/ aux) - model.predict_batch_with_aux({}, batch) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertNotIn('decode_rng', decode_fn_kwargs) - - # decoder_params RNG - model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3}) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertEqual(decode_fn_kwargs['decode_rng'], 3) - - # rng RNG - model.predict_batch({}, batch, rng=4) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertEqual(decode_fn_kwargs['decode_rng'], 4) - - # rng RNG (w/ aux) - model.predict_batch_with_aux({}, batch, rng=4) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertEqual(decode_fn_kwargs['decode_rng'], 4) - - # Both - with self.assertRaisesWithLiteralMatch( - ValueError, - 'Got RNG both from the `rng` argument (4) and ' - "`decoder_params['decode_rng']` (3). Please specify one or the other.", - ): - model.predict_batch_with_aux( - {}, batch, rng=4, decoder_params={'decode_rng': 3} - ) - - @parameterized.named_parameters( - dict( - testcase_name='int32', - batch={ - 'encoder_input_tokens': np.zeros( - (BATCH_SIZE, ENCODER_LEN), dtype=np.int32 - ), - 'decoder_input_tokens': np.zeros( - (BATCH_SIZE, MAX_DECODE_LEN), dtype=np.int32 - ), - }, - ), - dict( - testcase_name='float32', - batch={ - 'encoder_input_tokens': np.zeros( - (BATCH_SIZE, ENCODER_LEN), dtype=np.float32 - ), - 'decoder_input_tokens': np.zeros( - (BATCH_SIZE, MAX_DECODE_LEN), dtype=np.float32 - ), - }, - ), - ) - def test_predict_batch_fake_input_shapes_and_types(self, batch): - - # These dummy logits represent the probability distribution where all the - # probability mass is in one item (i.e., degenerate distribution). For - # batch element 0, it is vocabulary index 2. - # We test `_predict_step` to avoid having to define a task and its - # vocabulary. - dummy_logits = jnp.ones((2, 1, 4), jnp.float32) - - class MockModule: - - def __init__(self): - self.dtype = jnp.float32 - self.call_args_list = [] - - def apply(self, *args, method=None, **kwargs): - # Not sure why this isn't a real Mock so just record the args/kwargs - self.call_args_list.append({'args': args, 'kwargs': kwargs}) - del args, kwargs - if method is None: # use for module.`__call__` - return (dummy_logits, {'cache': {}}) - else: - return method() - - def encode(self): - return jnp.zeros((BATCH_SIZE, ENCODER_LEN, EMBED_DIM)) - - def decode(self): - return (dummy_logits, {'cache': {}}) - - def mock_init(self): - self.module = MockModule() - self.module.scan_layers = False - self._default_decoder_params = models.DecoderParams() - self._input_vocabulary = mock.Mock(eos_id=1) - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = decoding.beam_search - self._inputs_bidirectional_attention = False - - with mock.patch.object( - models.EncoderDecoderModel, '__init__', new=mock_init - ): - model = models.EncoderDecoderModel() - model.predict_batch({}, batch) - - fake_inputs = jnp.ones_like(batch['encoder_input_tokens']) - fake_target = jnp.ones_like(batch['decoder_input_tokens']) - - cache_init_call = model.module.call_args_list[1] - self.assertEqual(cache_init_call['args'][0], {'params': {}}) - call_kwargs = cache_init_call['kwargs'] - np.testing.assert_allclose( - call_kwargs.pop('encoder_input_tokens'), fake_inputs - ) - np.testing.assert_allclose( - call_kwargs.pop('decoder_input_tokens'), fake_target - ) - np.testing.assert_allclose( - call_kwargs.pop('decoder_target_tokens'), fake_target - ) - self.assertEqual( - call_kwargs, - { - 'decode': True, - 'enable_dropout': False, - 'mutable': ['cache'], - }, - ) - - -class DecoderOnlyModelTest(parameterized.TestCase): - - - - def test_predict_batch_visible_in_prefill(self): - batch_size = 2 - seq_len = 10 - lengths = np.array([[6], [3]]) - batch = { - 'decoder_input_tokens': np.tile( - np.expand_dims(np.arange(seq_len, dtype=np.int32), axis=0), - (batch_size, 1), - ), - 'decoder_causal_attention': (lengths > np.arange(seq_len)).astype( - np.int32 - ), - } - - dummy_logits = jnp.expand_dims( - jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1 - ) - - mock_module = mock.Mock() - mock_module.apply.return_value = (dummy_logits, {'cache': {}}) - mock_module.dtype = jnp.float32 - - def mock_init(self): - self.module = mock_module - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) - self._inputs_bidirectional_attention = False - - with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): - model = models.DecoderOnlyModel() - - model.predict_batch({}, batch) - prefill_call = mock_module.apply.call_args_list[1] - kwargs = prefill_call[1] - inputs = prefill_call[1]['decoder_input_tokens'] - # Note that, for the prefill call, we use 'decoder_causal_attention' as - # 'decoder_target_tokens'. - targets = prefill_call[1]['decoder_target_tokens'] - self.assertTrue(kwargs['prefill']) - np.testing.assert_array_equal( - kwargs['prefill_lengths'], np.squeeze(lengths - 1, axis=-1) - ) - # Test that the non padding values of the "targets" cover all of the input, - # you it will all be considered in the attention mask. - np.testing.assert_array_equal(inputs * targets, inputs) - # Check that the first value of the target is 1, the first value of the - # inputs is always 0 so the masking check wouldn't catch it if the target - # had a 0 in the first location. - np.testing.assert_array_equal(targets[:, 0], np.ones_like(targets[:, 0])) - # Test that the targets are properly removed. Our input is a sequence from 0 - # onward, so our largest value (the last input) should be equal by its - # position (which is 1 - length). If we didn't mask the target correctly, - # we would expect a larger value in the max. - np.testing.assert_array_equal( - np.max(inputs, axis=1), np.squeeze(lengths - 1, axis=-1) - ) - - - def test_predict_batch(self): - batch = { - 'decoder_input_tokens': np.array( - [[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]] - ), - 'decoder_causal_attention': np.array( - [[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]] - ), - } - - # These dummy logits represent the probability distribution where all the - # probability mass is in one item (i.e., degenerate distribution). For - # batch element 0, it is vocabulary index 2. - # We test `_predict_step` to avoid having to define a task and its - # vocabulary. - dummy_logits = jnp.expand_dims( - jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1 - ) - - mock_module = mock.Mock() - mock_module.apply.return_value = (dummy_logits, {'cache': {}}) - mock_module.dtype = jnp.float32 - - def mock_init(self): - self.module = mock_module - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) - self._inputs_bidirectional_attention = False - - with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): - model = models.DecoderOnlyModel() - - actual = model.predict_batch({}, batch) - - expected = [[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 3, 0]] - - # The expected progression of the first element of 'decoder_input_tokens': - # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0] -> - # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0] - - # The expected progression of the second element of 'decoder_input_tokens': - # [0, 7, 8, 9, 0, 0, 0] -> [0, 7, 0, 0, 0, 0, 0] -> - # [7, 3, 3, 3, 3, 3, 3] -> [3, 3, 3, 3, 3, 3, 0] - - np.testing.assert_array_equal(actual, expected) - - def test_predict_batch_rng(self): - batch = { - 'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32), - 'decoder_causal_attention': np.zeros((2, 2), dtype=np.int32), - } - - decode_fn_mock = mock.Mock( - return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2))) - ) - - def mock_init(self): - self.module = mock.Mock( - apply=mock.Mock( - side_effect=lambda *_, **kwargs: ( # pylint:disable=g-long-lambda,g-long-ternary - np.zeros((2, 2)), - {'cache': None}, - ) - if 'mutable' in kwargs - else np.zeros((2, 2)) - ) - ) - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = decode_fn_mock - self._inputs_bidirectional_attention = False - - with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): - model = models.DecoderOnlyModel() - - # No RNG - model.predict_batch({}, batch) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertNotIn('decode_rng', decode_fn_kwargs) - - # No RNG (w/ aux) - model.predict_batch_with_aux({}, batch) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertNotIn('decode_rng', decode_fn_kwargs) - - # decoder_params RNG - model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3}) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertEqual(decode_fn_kwargs['decode_rng'], 3) - - # rng RNG - model.predict_batch({}, batch, rng=4) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertEqual(decode_fn_kwargs['decode_rng'], 4) - - # rng RNG (w/ aux) - model.predict_batch_with_aux({}, batch, rng=4) - _, decode_fn_kwargs = decode_fn_mock.call_args - self.assertEqual(decode_fn_kwargs['decode_rng'], 4) - - # Both - with self.assertRaisesWithLiteralMatch( - ValueError, - 'Got RNG both from the `rng` argument (4) and ' - "`decoder_params['decode_rng']` (3). Please specify one or the other.", - ): - model.predict_batch_with_aux( - {}, batch, rng=4, decoder_params={'decode_rng': 3} - ) - - def test_predict_batch_num_decodes_temperature_sample(self): - batch = { - 'decoder_input_tokens': np.array([ - [0, 3, 4, 5, 6, 0, 0], - ]), - 'decoder_causal_attention': np.array([ - [1, 1, 1, 0, 0, 0, 0], - ]), - } - - # These dummy logits represent the probability distribution where all the - # probability mass is in one item (i.e., degenerate distribution). For - # batch element 0, it is vocabulary index 2. We have two samples. - # Technically these should be identical since the prompts are the same, but - # this makes testing easier. - dummy_logits = jnp.expand_dims( - jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1 - ) - - mock_module = mock.Mock() - mock_module.apply.return_value = (dummy_logits, {'cache': {}}) - mock_module.dtype = jnp.float32 - - def mock_init(self): - self.module = mock_module - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) - self._inputs_bidirectional_attention = False - - with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): - model = models.DecoderOnlyModel() - - actual_output, aux = model.predict_batch_with_aux( - {}, batch, num_decodes=2, return_all_decodes=True - ) - - expected_output = [[[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 0, 0]]] - expected_scores = [[0.0, 0.0]] - - # The expected progression of the first element of 'decoder_input_tokens': - # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0] -> - # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0] - - # The expected progression of the second element of 'decoder_input_tokens': - # [0, 7, 8, 9, 0, 0, 0] -> [0, 7, 0, 0, 0, 0, 0] -> - # [7, 3, 3, 3, 3, 3, 3] -> [3, 3, 3, 3, 3, 3, 0] - - np.testing.assert_array_equal(actual_output, expected_output) - np.testing.assert_array_equal(aux['scores'], expected_scores) - - def test_predict_batch_fake_input_shapes_and_types(self): - # The input and causal attention actually have to be int32 for this test, - # even though the cache init should work with any types the `inputs` that - # is created from multiplying the causal attention and the input tokens - # needs to be an int or the decoding will fail. - batch = { - 'decoder_input_tokens': np.array( - [[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]], dtype=np.int32 - ), - 'decoder_causal_attention': np.array( - [[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]], dtype=np.int32 - ), - } - - dummy_logits = jnp.ones((2, 1, 5), jnp.float32) - - mock_module = mock.Mock() - mock_module.apply.return_value = (dummy_logits, {'cache': {}}) - mock_module.dtype = jnp.float32 - - def mock_init(self): - self.module = mock_module - self._output_vocabulary = mock.Mock(eos_id=1) - self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) - self._inputs_bidirectional_attention = False - - with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): - model = models.DecoderOnlyModel() - - model.predict_batch({}, batch) - - fake_target = jnp.ones_like(batch['decoder_input_tokens']) - - cache_init_call = mock_module.apply.call_args_list[0] - - self.assertEqual(cache_init_call[0][0], {'params': {}}) - np.testing.assert_allclose(cache_init_call[0][1], fake_target) - np.testing.assert_allclose(cache_init_call[0][2], fake_target) - self.assertEqual( - cache_init_call[1], - {'decode': True, 'enable_dropout': False, 'mutable': ['cache']}, - ) - - @parameterized.named_parameters( - dict( - testcase_name='no_types', - shapes={'decoder_input_tokens': [1, 62]}, - types=None, - ), - dict( - testcase_name='int32', - shapes={'decoder_input_tokens': [1, 62]}, - types={'decoder_input_tokens': jnp.int32}, - ), - dict( - testcase_name='float32', - shapes={'decoder_input_tokens': [1, 62]}, - types={'decoder_input_tokens': jnp.int32}, - ), - ) - def test_get_initial_variables_shapes_and_types(self, shapes, types): - mock_lm = mock.Mock() - mock_lm.init.return_value = {'params': {}} - mock_optimizer_def = mock.Mock() - rng = mock.Mock() - - def mock_init(self): - self.module = mock_lm - self.optimizer_def = mock_optimizer_def - - with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): - model = models.DecoderOnlyModel() - model.get_initial_variables(rng, shapes, types) - - if types is None: - decoder_input = jnp.ones( - shapes['decoder_input_tokens'], dtype=jnp.float32 - ) - else: - decoder_input = jnp.ones( - shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens'] - ) - - # Using `.assert_called_once_with` doesn't work because the simple - # comparison it does for the array arguments fail (truth value of an array - # is ambiguous). - called_with = mock_lm.init.call_args - self.assertEqual(called_with[0][0], rng) - np.testing.assert_allclose(called_with[0][1], decoder_input) - np.testing.assert_allclose(called_with[0][2], decoder_input) - self.assertEqual(mock_lm.init.call_args[1], {'enable_dropout': False}) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/notebooks/README.md b/t5x-main/t5x/notebooks/README.md deleted file mode 100644 index cb047546fec1abe9b1bd7afa808182941967d44a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/notebooks/README.md +++ /dev/null @@ -1,76 +0,0 @@ -# Prepare Colab Runtime - -Currently the [default public Colab](https://colab.research.google.com/) cannot be easily used to run T5X models. Here we provide an alternative, i.e., creating a custom jupyter kernel/runtime via Google Cloud TPU VM. One can then use the `Connect to a local runtime` option run the notebooks in this folder. - -## Create TPU VM -You should follow T5X's main README.md [installation guide](https://github.com/google-research/t5x#installation) to setup a GCP account. - -Then create a TPU VM via the command below (make sure to change `TPUVMNAME` and `TPUVMZONE` accordingly) - -``` -export TPUVMNAME=xxxx; -export TPUVMZONE=xxxxxxx; -export TPUTYPE=v3-8; -export APIVERSION=v2-alpha - -gcloud alpha compute tpus tpu-vm create ${TPUVMNAME} --zone=${TPUVMZONE} --accelerator-type=${TPUTYPE} --version=${APIVERSION} -``` - -## ssh to TPU VM -You need to set proper firewall rules to be able to ssh into the VM. - -``` -gcloud compute firewall-rules create default-allow-ssh --allow tcp:22 -``` - -ssh into the VM with port forwarding (`8888` is often used for ipython notebook kernel) - -``` -gcloud compute tpus tpu-vm ssh ${TPUVMNAME} --zone=${TPUVMZONE} -- -L 8888:localhost:8888 -``` - -## Prepare python env - -Create a python environment via - -``` -sudo apt update -sudo apt install -y python3.9 python3.9-venv -python3.9 -m venv t5_venv -``` - -Then install T5X with its dependencies. - -``` -source t5_venv/bin/activate -python3 -m pip install -U pip setuptools wheel ipython -pip install flax -git clone --branch=main https://github.com/google-research/t5x -cd t5x -python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -cd - -``` - -After this, we can test if we can accessed TPU successfully by (should print out a list of TPU devices) - -``` -python3 -c "import jax; print(jax.local_devices())" -``` - -At last, we prepare necessary packages to allow the jupyter kernel can be access by our colab notebooks. - -``` -pip install notebook -pip install --upgrade jupyter_http_over_ws>=0.0.7 -jupyter serverextension enable --py jupyter_http_over_ws -``` - -## Launch runtime - -Use the command below to launch the prepared runtime. - -``` -jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0 -``` - -from the log of the above command, you can see an http link starting with `http://localhost:8888/?token`s. Copy and paste it into the `Connect to a local runtime` option and now you should be able to run T5X colab notebooks. diff --git a/t5x-main/t5x/notebooks/evaluation.ipynb b/t5x-main/t5x/notebooks/evaluation.ipynb deleted file mode 100644 index ebb9add076ee6cd7cfa08c5ae0d185963ba66850..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/notebooks/evaluation.ipynb +++ /dev/null @@ -1,1019 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "-0BQWhvAP2jb" - }, - "source": [ - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/evaluation.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bqZYp90PIa1t" - }, - "source": [ - "# Overview\n", - "\n", - "This is the fourth Colab in a [series of tutorials on how to use T5X](https://github.com/google-research/t5x/blob/main/docs/tutorials.md). We assume that you have already completed the [Introductory Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb), the [Training Deep Dive](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb), and the [Inference Deep Dive](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb), or have a basic understanding of the T5X models, checkpoints, partitioner, trainer, and `InteractiveModel`.\n", - "\n", - "In the [previous Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb) in this tutorial series, we dove into how the `InteractiveModel` does decoding to generate predictions and scores for a given input. We will now focus on how the InteractiveModel takes a batch of inputs and targets and runs evaluation to produce various metrics. It should be noted that the code snippets below exactly replicate the InteractiveModel `__init__()` and `evaluate()` methods (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); we expose this functionality here in order to demonstrate how various components of the T5X codebase work together to perform evaluation on a model." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nZJbWZcfkyxI" - }, - "source": [ - "# Set-Up" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VkVjJOewsMM8" - }, - "source": [ - "Note: If you are a using public colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jIGSIHzD7YPO" - }, - "outputs": [], - "source": [ - "from collections.abc import Sequence\n", - "import enum\n", - "import functools\n", - "import inspect\n", - "import itertools\n", - "import logging\n", - "import os\n", - "import re\n", - "from typing import Any, Callable, Iterator, Optional, Tuple, Union\n", - "\n", - "import jax\n", - "from jax import random\n", - "from jax.experimental import multihost_utils\n", - "import numpy as np\n", - "import seqio\n", - "import tensorflow as tf\n", - "import tensorflow_datasets as tfds\n", - "from t5.evaluation import metrics as t5_metrics\n", - "import t5.data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Mt7gxVc9sVjN" - }, - "outputs": [], - "source": [ - "import clu.data\n", - "from t5x.examples.t5 import network\n", - "import t5x\n", - "from t5x import models\n", - "from t5x import partitioning\n", - "from t5x import trainer as trainer_lib\n", - "from t5x import utils\n", - "from t5x.infer import _extract_tokens_and_aux_values\n", - "from t5x.infer import _Inferences\n", - "from t5x.interactive_model import InteractiveModel\n", - "from t5x.interactive_model import get_batches_from_seqio\n", - "from t5x.interactive_model import get_dataset_from_natural_text_examples\n", - "from t5x.interactive_model import get_gin_config_from_interactive_model\n", - "from t5x.interactive_model import T5XScriptType\n", - "from t5x.interactive_model import InferenceType" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S5Lb-Z1fkF5a" - }, - "source": [ - "Before we begin, let's initialize instances of the constructor arguments for the `InteractiveModel`. As mentioned previously, this will enable us to dive into how the `InteractiveModel` runs inference.\n", - "\n", - "If you don't understand the lines of code below, or have questions about how to initialize these parameters, please see the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ne8U8qoWkX_r" - }, - "outputs": [], - "source": [ - "# Define a model. The configuration below corresponds to the T5 1.1 Small model.\n", - "t5_config = network.T5Config(\n", - " vocab_size=32128,\n", - " dtype='bfloat16',\n", - " emb_dim=512,\n", - " num_heads=6,\n", - " num_encoder_layers=8,\n", - " num_decoder_layers=8,\n", - " head_dim=64,\n", - " mlp_dim=1024,\n", - " mlp_activations=('gelu', 'linear'),\n", - " dropout_rate=0.0,\n", - " logits_via_embedding=False)\n", - "module = network.Transformer(config=t5_config)\n", - "model = t5x.models.EncoderDecoderModel(\n", - " module=module,\n", - " input_vocabulary=t5.data.get_default_vocabulary(),\n", - " output_vocabulary=t5.data.get_default_vocabulary(),\n", - " optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))\n", - "# Define checkpoint arguments.\n", - "checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'\n", - "dtype='bfloat16'\n", - "restore_mode='specific'\n", - "# Define a partitioner.\n", - "partitioner=partitioning.PjitPartitioner(num_partitions=2)\n", - "# Define additional, miscellaneous constructor arguments.\n", - "batch_size=8\n", - "task_feature_lengths = {'inputs': 38, 'targets': 18}\n", - "output_dir='/tmp/output_dir'\n", - "input_shapes = {\n", - " 'encoder_input_tokens': np.array([8, 38]),\n", - " 'decoder_target_tokens': np.array([8, 18]),\n", - " 'decoder_input_tokens': np.array([8, 18]),\n", - " 'decoder_loss_weights': np.array([8, 18])\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EYwdg-fFTU8Q" - }, - "source": [ - "In addition, we will run all code that is performed when we initialize the InteractiveModel. If you don't understand the lines of code below or have any additional questions about how/why we do the steps below, please see the [second Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in our tutorial series." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YmGTJBAcTpMR" - }, - "outputs": [], - "source": [ - "# 1.) Configure the Output Directory\n", - "output_dir = re.sub(r\"(?\u003c!gs:)([\\/]{2,})\", \"/\", output_dir)\n", - "if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - "\n", - "# 2.) Initialize RNGs\n", - "init_random_seed = 42\n", - "random_seed = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed))\n", - "utils.set_hardware_rng_ops()\n", - "rng = random.PRNGKey(random_seed)\n", - "init_rng, trainer_rng = random.split(rng, 2)\n", - "\n", - "# 3.) Validate the Partitioner\n", - "if partitioner._model_parallel_submesh:\n", - " num_partitions = np.prod(partitioner._model_parallel_submesh)\n", - "else:\n", - " num_partitions = partitioner._num_partitions\n", - "if jax.device_count() % num_partitions != 0:\n", - " raise ValueError(\n", - " \"The number of devices available must be a multiple of the number of\",\n", - " f\" partitions. There are {jax.device_count()} devices available, but\",\n", - " f\" the number of partitions is set to {num_partitions}. Please\",\n", - " \" provide a different number of partitions.\")\n", - "\n", - "# 4.) Create a Checkpoint Manager\n", - "# a.) Define CheckpointCfg wrappers.\n", - "save_checkpoint_cfg = utils.SaveCheckpointConfig(\n", - " dtype=dtype,\n", - " keep=5, # The number of checkpoints to keep in the output_dir.\n", - " save_dataset=False)\n", - "restore_checkpoint_cfg = utils.RestoreCheckpointConfig(\n", - " dtype=dtype,\n", - " mode=restore_mode,\n", - " path=checkpoint_path)\n", - "\n", - "# b.) Define a train state initializer, which will help us get information about the\n", - "# TrainState shape.\n", - "train_state_initializer = utils.TrainStateInitializer(\n", - " optimizer_def=model.optimizer_def,\n", - " init_fn=model.get_initial_variables,\n", - " input_shapes=input_shapes,\n", - " input_types=None,\n", - " partitioner=partitioner)\n", - "\n", - "# c.) Define the checkpoint manager.\n", - "checkpoint_manager = utils.LegacyCheckpointManager(\n", - " save_cfg=save_checkpoint_cfg,\n", - " restore_cfg=restore_checkpoint_cfg,\n", - " train_state_shape=train_state_initializer.global_train_state_shape,\n", - " partitioner=partitioner,\n", - " ds_iter=None,\n", - " model_dir=output_dir)\n", - "\n", - "### 5.) Restore the Model from a Checkpoint, or Initialize from Scratch ###\n", - "def get_state(rng):\n", - " return train_state_initializer.from_scratch(rng).state_dict()\n", - "\n", - "# a.) Try to restore a model from a checkpoint.\n", - "train_state = checkpoint_manager.restore(\n", - " [restore_checkpoint_cfg.path],\n", - " restore_checkpoint_cfg,\n", - " utils.get_fallback_state(restore_checkpoint_cfg, get_state, init_rng)\n", - ")\n", - "\n", - "# b.) If no checkpoint to restore, init from scratch.\n", - "if train_state is None:\n", - " train_state = train_state_initializer.from_scratch(init_rng)\n", - "\n", - "output_features = {\n", - " \"inputs\":\n", - " seqio.Feature(\n", - " vocabulary=model.input_vocabulary, add_eos=True),\n", - " \"targets\":\n", - " seqio.Feature(\n", - " vocabulary=model.output_vocabulary, add_eos=True)\n", - " }\n", - "features = dict(sorted(output_features.items()))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ckw3l3Go_ZDL" - }, - "source": [ - "Finally, the InteractiveModel defines a `self.infer_with_preprocessors` method that we will need to reference in order to run evaluation. However, we are breaking down the InteractiveModel functionality and do not actually use an instance of the InteractiveModel in this Colab. Thus, we will duplicate this class method below.\n", - "\n", - "If you don't understand the lines of code below or have any additional questions about how/why we do the steps below, please see the third Colab in our tutorial series: https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FpSheT1hATQ2" - }, - "outputs": [], - "source": [ - "def infer_with_preprocessors(\n", - " mode: InferenceType, examples: Sequence[Union[str, dict[str, str]]],\n", - " preprocessors: Sequence[Callable[..., tf.data.Dataset]]) -\u003e _Inferences:\n", - " \"\"\"Infer function.\n", - "\n", - " Args:\n", - " mode: Either 'score' to compute the log likelihood of given targets, or\n", - " 'predict_with_aux' to score and decode targets.\n", - " examples: a single batch of examples that should be transformed into a\n", - " tf.data.Dataset. The examples can either take the form of a string (ex:\n", - " a single input for inference), or a dictionary mapping \"input\"/\"target\"\n", - " to a string containing that element.\n", - " preprocessors: list(callable), an optional list of functions that receive\n", - " a tf.data.Dataset and return a tf.data.Dataset. These will be executed\n", - " sequentially and the final dataset must include features matching\n", - " `features`.\n", - "\n", - " Returns:\n", - " Returns a tuple of predictions/scores and any auxiliary values.\n", - " \"\"\"\n", - " # --------------------------------------------------------------------------\n", - " # Parse Mode\n", - " # --------------------------------------------------------------------------\n", - " if mode == InferenceType.PREDICT_WITH_AUX:\n", - " infer_step = model.predict_batch_with_aux\n", - " elif mode == InferenceType.SCORE:\n", - " infer_step = model.score_batch\n", - " else:\n", - " raise ValueError(\"Mode must be `predict_with_aux`, or `score`,\"\n", - " f\" but instead was {mode}.\")\n", - " infer_fn = functools.partial(\n", - " utils.get_infer_fn(\n", - " infer_step=infer_step,\n", - " batch_size=batch_size,\n", - " train_state_axes=train_state_initializer.train_state_axes,\n", - " partitioner=partitioner),\n", - " train_state=train_state)\n", - "\n", - " # --------------------------------------------------------------------------\n", - " # Construct a dataset and dataset iterator.\n", - " # --------------------------------------------------------------------------\n", - " dataset = get_dataset_from_natural_text_examples(\n", - " examples,\n", - " preprocessors=preprocessors,\n", - " task_feature_lengths=task_feature_lengths,\n", - " features=features)\n", - " feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)\n", - " model_dataset = feature_converter(\n", - " dataset, task_feature_lengths=task_feature_lengths)\n", - " # Zip task and model features.\n", - " infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))\n", - " # Create batches and index them.\n", - " infer_dataset = infer_dataset.padded_batch(\n", - " batch_size, drop_remainder=False).enumerate()\n", - " infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(\n", - " infer_dataset.prefetch(tf.data.experimental.AUTOTUNE))\n", - "\n", - " # --------------------------------------------------------------------------\n", - " # Run inference\n", - " # --------------------------------------------------------------------------\n", - " # Main Loop over \"batches\".\n", - " all_inferences = []\n", - " all_aux_values = {}\n", - " for chunk, chunk_batch in infer_dataset_iter:\n", - " # Load the dataset for the next chunk. We can't use `infer_dataset_iter`\n", - " # directly since `infer_fn` needs to know the exact size of each chunk,\n", - " # which may be smaller for the final one.\n", - " chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)\n", - " chunk_dataset.cache().prefetch(tf.data.experimental.AUTOTUNE)\n", - "\n", - " # Unzip chunk dataset in to pretokenized and model datasets.\n", - " task_dataset = chunk_dataset.map(\n", - " lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", - " model_dataset = chunk_dataset.map(\n", - " lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", - "\n", - " # Get a chunk-specific RNG key.\n", - " chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)\n", - "\n", - " inferences = _extract_tokens_and_aux_values(\n", - " infer_fn(model_dataset.enumerate(), rng=chunk_rng))\n", - "\n", - " predictions, aux_values = inferences\n", - " accumulated_inferences = []\n", - " for idx, inputs in task_dataset.enumerate().as_numpy_iterator():\n", - " prediction = predictions[idx]\n", - " # Decode predictions if applicable.\n", - " if mode == InferenceType.PREDICT_WITH_AUX:\n", - " prediction =features[\"targets\"].vocabulary.decode_tf(\n", - " tf.constant(prediction)).numpy()\n", - " accumulated_inferences.append((inputs, prediction))\n", - " all_inferences += accumulated_inferences\n", - " # Accumulate aux values over batches.\n", - " if not all_aux_values:\n", - " all_aux_values = aux_values\n", - " else:\n", - " for key, values in aux_values.items():\n", - " all_aux_values[key] += values\n", - "\n", - " return all_inferences, all_aux_values" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ib9aOi2xaCKQ" - }, - "source": [ - "# Evaluation Deep Dive" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ANqpfv0lAVqL" - }, - "source": [ - "**Defining a Batch of Examples to Run Inference On**\\\n", - "Let's start by defining a batch of examples that we will use to evaluate our model.\n", - "\n", - "These examples should be a list of dictionaries mapping 'target'/'input' keys to corresponding values, as shown below. For this Colab, we'll use a set of natural test questions and answers." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yhhR0yDcAn7w" - }, - "outputs": [], - "source": [ - "examples = [\n", - " {\n", - " 'target': b'Ajay Tyagi',\n", - " 'input':b'nq question: who has been appointed as the new chairman of sebi'\n", - " },\n", - " {\n", - " 'target': b'C. S. Lewis',\n", - " 'input': b'nq question: who wrote the book lion the witch and the wardrobe'},\n", - " {\n", - " 'target': b'29',\n", - " 'input': b'nq question: how many planes did japan lose at pearl harbor'},\n", - " {\n", - " 'target': b'Jack Keil',\n", - " 'input': b'nq question: who does the voice of mcgruff the dog'},\n", - " {\n", - " 'target': b'Journey',\n", - " 'input': b'nq question: who sings the wheels in the sky keep on turning'},\n", - " {\n", - " 'target': b'Kumiko Watanabe',\n", - " 'input': b'nq question: who voices regina in glitter force doki doki'},\n", - " {\n", - " 'target': b'during World War II',\n", - " 'input': b'nq question: when did the us become allies with britain'},\n", - " {\n", - " 'target': b'the United States',\n", - " 'input': b'nq question: who won the rugby 7 in las vegas'},\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WYV1LMS5taE9" - }, - "source": [ - "We also define the required features of the examples. For this Colab, we will only require an `inputs` and `targets` entry, as defined below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nj5I7YMotb9U" - }, - "outputs": [], - "source": [ - "output_features = {\n", - " \"inputs\":\n", - " seqio.Feature(\n", - " vocabulary=model.input_vocabulary, add_eos=True),\n", - " \"targets\":\n", - " seqio.Feature(\n", - " vocabulary=model.output_vocabulary, add_eos=True)\n", - " }\n", - "features = dict(sorted(output_features.items()))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hEG2_HbVGb4y" - }, - "source": [ - "**Defining a Metrics Function**\\\n", - "Next, we'll need to determine what metrics we want to use to evaluate our model, and we'll need to define a metrics function to produce these values.\n", - "\n", - "We support two types of metrics: \\\n", - "1.) *Prediction-based metrics*: these are metrics that depend on model predictions; the metric may also rely on additional auxiliary values. For example, if our model produces an output sequence, a valid prediction-based metric would be BLEU, which compares our output sequence to a target sequence. \\\n", - "2.) *Score-based metrics*: these are metrics the depend on model scores. For example, log likelihood of a target sequence (given an input sequence) would be a valid score-based metrics function.\n", - "\n", - "For more details on metrics function, please see this [Metrics](https://github.com/google/seqio/blob/main/README.md/index#metrics) documentation.\n", - "\n", - "It is rare that you will actually have to define your own metrics function; unless you are working on a custom/novel metric, you can likely find a predefined metrics function to call on. For example, many common language evaluation metrics are defined in [t5.evaluation.metrics](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/evaluation/metrics.py). For this Colab, we will evaluate our natural question/answer pairs using the predefined SQuAD metrics from `t5.evaluation.metrics`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vh3NQDwvI5xr" - }, - "outputs": [], - "source": [ - "metric_fns = [t5_metrics.squad]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2Y6ebToiR2iF" - }, - "source": [ - "**Defining a Postprocessor Function** \\\n", - "Some metrics functions require postprocessing targets before we are able to calculate the metrics. The InteractiveModel allows users to optionally provide a postprocessor to convert targets to the intended form; see this [Postprocessor](https://github.com/google/seqio/blob/main/README.md/index#postprocessor) documentation for more details.\n", - "\n", - "For this example, we will use a standard QA postprocessor, modeled after the [`t5.data.postprocessors.qa` method](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/postprocessors.py)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Mml43lATSH7D" - }, - "outputs": [], - "source": [ - "def qa(answer, example=None, is_target=False):\n", - " \"\"\"Returns answer, or all answers if the full example is provided.\"\"\"\n", - " if is_target:\n", - " return [tf.compat.as_text(a) for a in [example[\"targets_pretokenized\"]]]\n", - " return answer\n", - "\n", - "postprocessor = qa" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mixLzcBkQOT_" - }, - "source": [ - "Now, let's break down what the interactive model does to run evaluation.\n", - "\n", - "The `InteractiveModel` `evaluate()` method performs four actions:\n", - "\n", - "\n", - "1. Convert the natural text examples into a tf.Dataset.\n", - "2. Detect the metric function type. We analyze the metrics function signatures to determine if the metrics are prediction-based or score-based.\n", - "3. Run inference to generate predictions and/or scores depending on the metrics function types.\n", - "4. Run the metrics functions on the provided predictions/scores and return these metrics to the user.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ug0zJx2kQk6g" - }, - "source": [ - "**Prepare the dataset** \\\n", - "\n", - "Preparing the data for evaluation is fairly straightforward; in fact, this is nearly the same data preparation that happens for training.\n", - "\n", - "First, we convert the natural text examples into a tf.Dataset and run any preprocessors; T5X has a helper function, `get_dataset_from_natural_text_examples`, that can do exactly that. For this example, the only preprocessing we will do is tokenization and appending an EOS token. If you are interested in learning more about preprocessors, please take a look at https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-intro.\n", - "\n", - "Finally, we may optionally postprocess the targets (if a postprocessor has been provided)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "chPomDFxQ6r3" - }, - "outputs": [], - "source": [ - "preprocessors= [\n", - " seqio.preprocessors.tokenize,\n", - " seqio.preprocessors.append_eos\n", - "]\n", - "dataset = get_dataset_from_natural_text_examples(\n", - " examples,\n", - " preprocessors=preprocessors,\n", - " task_feature_lengths=task_feature_lengths,\n", - " features=features)\n", - "\n", - "# Postprocess targets if required.\n", - "def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -\u003e Any:\n", - " \"\"\"Returns the model output after applying the postprocess function.\"\"\"\n", - " if postprocessor:\n", - " return postprocessor(decoded_model_output, **postprocess_kwargs)\n", - " return decoded_model_output\n", - "\n", - "targets = []\n", - "for ex in tfds.as_numpy(dataset):\n", - " targets.append(\n", - " postprocess_fn(\n", - " decoded_model_output=ex[\"targets_pretokenized\"],\n", - " example=ex,\n", - " is_target=True))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JEsB1uzfSi3L" - }, - "source": [ - "**Parse Metrics Functions** \\\n", - "Next, we inspect the function signature for all metrics functions to determine whether the metrics are prediction-based or score-based. Further, we also detect whether the prediction-based metrics require auxiliary values.\n", - "\n", - "This check is fairly rudimentary; we simply look at the arguments for the metrics functions and categorize the function based on whether \"scores\", \"predictions\", and/or \"aux_values\" appear as arguments to the functions.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wfLMdo6FS0aD" - }, - "outputs": [], - "source": [ - "predict_metric_fns = []\n", - "predict_with_aux_metric_fns = []\n", - "score_metric_fns = []\n", - "for metric_fn in metric_fns:\n", - " pos_args = tuple(\n", - " key for key, param in inspect.signature(metric_fn).parameters.items()\n", - " if param.default == inspect.Parameter.empty)\n", - " if pos_args == (\"targets\", \"scores\"):\n", - " score_metric_fns.append(metric_fn)\n", - " elif pos_args == (\"targets\", \"predictions\"):\n", - " predict_metric_fns.append(metric_fn)\n", - " elif pos_args == (\"targets\", \"predictions\", \"aux_values\"):\n", - " predict_with_aux_metric_fns.append(metric_fn)\n", - " else:\n", - " raise ValueError(\n", - " \"Metric functions must have positional arguments matching either \"\n", - " \"('targets', 'scores'), ('targets', 'predictions') or \"\n", - " \"('targets', 'predictions', 'aux_values'). \"\n", - " f\"Got: {pos_args}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_Ynv_W3aS3ai" - }, - "source": [ - "**Run Inference** \\\n", - "Next, we extract predictions and/or scores depending on the types of our metrics functions. We simply use our `infer_with_preprocessors` helper (in the InteractiveModel, we use the `self.infer_with_preprocessors` class method). For more details on inference in the InteractiveModel, please see https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ACBEyqlWS8EF" - }, - "outputs": [], - "source": [ - "# Get predictions.\n", - "predictions = []\n", - "if predict_with_aux_metric_fns or predict_metric_fns:\n", - " predictions, aux_values = infer_with_preprocessors(\n", - " mode=InferenceType.PREDICT_WITH_AUX,\n", - " examples=examples,\n", - " preprocessors=preprocessors)\n", - " predictions = [\n", - " prediction.decode(\"utf-8\") for example, prediction in predictions\n", - " ]\n", - "\n", - "# Get scores.\n", - "scores = []\n", - "if score_metric_fns:\n", - " scores, _ = infer_with_preprocessors(\n", - " mode=InferenceType.SCORE,\n", - " examples=examples,\n", - " preprocessors=preprocessors)\n", - " scores = [score for example, score in scores]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hTKh1RNNS_BS" - }, - "source": [ - "**Compute Metrics** \\\n", - "Finally, we define and call a helper function to compute metrics given our inputs, predictions/scores, targets, and metrics functions.\n", - "\n", - "This core functionality of this helper is fairly straightforward and is defined in the inner `compute_metrics_fn`. This function simply iterates over all the metrics functions, passing the correct inputs (predictions, scores, and/or auxiliary values) to each metrics function to calculate the value of that metric. We then create a dictionary mapping the metric name to the value of that metric.\n", - "\n", - "There is a bit of logic that wraps around this `compute_metrics_fn` that enables us to run these computations in a multihost environment. In particular, we ensure that we only calculate metrics once, and appropriately wrap `compute_metrics_fn` in a TF computation graph if necessary." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1UyF8F47TCZI" - }, - "outputs": [], - "source": [ - "def compute_metrics(\n", - " targets: Sequence[Any], predictions: Sequence[Any],\n", - " aux_values: Sequence[Any], scores: Sequence[Any],\n", - " predict_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],\n", - " predict_with_aux_metric_fns: Sequence[\n", - " seqio.dataset_providers.MetricFnCallable],\n", - " score_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable]):\n", - " \"\"\"Computes the metrics specified in the metric_fns lists.\"\"\"\n", - " # Only compute metrics once\n", - " if jax.process_index() != 0:\n", - " return {}\n", - "\n", - " def compute_metrics_fn():\n", - " task_metrics = []\n", - " if predict_metric_fns:\n", - " task_metrics.extend([\n", - " metric_fn(targets, predictions) for metric_fn in predict_metric_fns\n", - " ])\n", - " if predict_with_aux_metric_fns:\n", - " task_metrics.extend([\n", - " metric_fn(targets, predictions, aux_values) for metric_fn in predict_with_aux_metric_fns\n", - " ])\n", - " if score_metric_fns:\n", - " is_tuple = isinstance(scores, tuple)\n", - " if ((not is_tuple and len(targets) != len(scores)) or\n", - " (is_tuple and len(targets) != len(scores[0]))):\n", - " raise ValueError(f\"len(targets)({len(targets)}) != \"\n", - " f\"len(output_scores)({len(scores)})\")\n", - " task_metrics.extend([\n", - " metric_fn(targets, scores) for metric_fn in score_metric_fns\n", - " ])\n", - "\n", - " all_metrics = {}\n", - " for k, v in itertools.chain(*[m.items() for m in task_metrics]):\n", - " if k in all_metrics:\n", - " raise ValueError(f\"Duplicate metric key '{k}' in Task.\")\n", - " all_metrics[k] = v\n", - " return all_metrics\n", - "\n", - " if not tf.executing_eagerly():\n", - " def wrap_graph(fn):\n", - " graph = tf.compat.v1.get_default_graph()\n", - " def wrapped_fn():\n", - " with graph.as_default():\n", - " return fn()\n", - " return wrapped_fn\n", - " compute_metrics_fn = wrap_graph(compute_metrics_fn)\n", - "\n", - " all_metrics = compute_metrics_fn()\n", - " # Wait until computations are done before continuing.\n", - " utils.sync_global_devices(\"Completed.\")\n", - " return all_metrics\n", - "\n", - "\n", - "metrics = compute_metrics(\n", - " targets,\n", - " predictions,\n", - " aux_values,\n", - " scores,\n", - " predict_metric_fns,\n", - " predict_with_aux_metric_fns,\n", - " score_metric_fns)\n", - "print(metrics)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AB0U_kfRNIyR" - }, - "source": [ - "The code snippets above exactly replicate the `InteractiveModel` `evaluate()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running `interactive_model.evaluate(examples, preprocessors=[seqio.preprocessors.tokenize, seqio.preprocessors.append_eos], metric_fns=[t5_metrics.squad], postprocessor=t5.data.postprocessors.qa)`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lcDwmp_AxnOG" - }, - "source": [ - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-QR5LnmN4ikp" - }, - "source": [ - "# Advanced Topics" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CLstCKpP8Ge7" - }, - "source": [ - "## T5X Evaluation Binaries and Other Advanced Features\n", - "\n", - "T5X offers evauation binaries that have the same functionality as the InteractiveModel, with additional features as well (more advanced compiling, etc.). Importantly, these binaries are configured using [Gin](https://github.com/google/gin-config/blob/main/README.md); if you are not familiar with Gin, please take a look at this [Gin Primer](https://github.com/google-research/t5x/blob/main/docs/usage.md/gin) to get started.\n", - "\n", - "If you are familiar with Gin and interested in using the T5X evaluation binaries, we have provided a helper function, `get_gin_config_from_interactive_model`, which will take an InteractiveModel instance and generate the gin config that you can use to run the T5X evaluation binaries; this gin config will exactly reproduce the InteractiveModel evaluation functionality we've described above. We've provided an example below.\n", - "\n", - "Importantly, the InteractiveModel takes in a model, partitioner, and data, so we cannot generate Gin configs for these components. You can pass Gin config strings for the model and partitioner components to the helper function, as demonstrated below. Additionally, you can pass a SeqIO task containing your data to the helper function. See the section below if you are unfamiliar with SeqIO." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rhgUZ0w6yQsE" - }, - "outputs": [], - "source": [ - "# Define an InteractiveModel instance, based on the `small` T5X EncoderDecoder model.\n", - "t5_config = network.T5Config(\n", - " vocab_size=32128,\n", - " dtype='bfloat16',\n", - " emb_dim=512,\n", - " num_heads=6,\n", - " num_encoder_layers=8,\n", - " num_decoder_layers=8,\n", - " head_dim=64,\n", - " mlp_dim=1024,\n", - " mlp_activations=('gelu', 'linear'),\n", - " dropout_rate=0.0,\n", - " logits_via_embedding=False)\n", - "module = network.Transformer(config=t5_config)\n", - "model = t5x.models.EncoderDecoderModel(\n", - " module=module,\n", - " input_vocabulary=t5.data.get_default_vocabulary(),\n", - " output_vocabulary=t5.data.get_default_vocabulary(),\n", - " optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),\n", - " decode_fn=functools.partial(\n", - " t5x.decoding.temperature_sample, temperature=1.0, topk=40))\n", - "interactive_model = InteractiveModel(\n", - " batch_size=8,\n", - " task_feature_lengths={'inputs': 32, 'targets': 32},\n", - " output_dir='/tmp/',\n", - " partitioner=partitioning.PjitPartitioner(\n", - " num_partitions=1,\n", - " model_parallel_submesh=None,\n", - " logical_axis_rules=partitioning.standard_logical_axis_rules()),\n", - " model=model,\n", - " dtype='float32',\n", - " restore_mode='specific',\n", - " checkpoint_path='',\n", - " input_shapes={\n", - " 'encoder_input_tokens': np.array([8, 38]),\n", - " 'decoder_target_tokens': np.array([8, 18]),\n", - " 'decoder_input_tokens': np.array([8, 18]),\n", - " 'decoder_loss_weights': np.array([8, 18])\n", - " },\n", - " input_types=None)\n", - "\n", - "# Define Gin Config strings for the model, partitioner, and any imports.\n", - "imports_str = \"\"\"from t5x import models\n", - "from t5x import partitioning\n", - "import t5.data.mixtures\n", - "include 't5x/examples/t5/t5_1_1/tiny.gin'\"\"\"\n", - "partitioner_config = 'partitioning.PjitPartitioner.num_partitions = 2'\n", - "model_config = \"\"\"models.EncoderDecoderModel:\n", - "z_loss = 0.0\n", - "label_smoothing = 0.0\n", - "loss_normalizing_factor = None\"\"\"\n", - "\n", - "gin_config_str = get_gin_config_from_interactive_model(\n", - " interactive_model=interactive_model,\n", - " script_type=T5XScriptType.EVALUATION,\n", - " task_name='wmt19_ende_v003',\n", - " partitioner_config_str=partitioner_config,\n", - " model_config_str=model_config,\n", - " train_steps=0,\n", - " imports_str=imports_str,\n", - ")\n", - "print(gin_config_str)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uGd1DxDT3gB7" - }, - "source": [ - "Once you have generated the `gin_config_str` as above, you can write this string to a file and launch your evaluation experiment locally by running the following on commandline:\n", - "\n", - "\n", - "```\n", - "EVAL_OUTPUT_DIR=\"/tmp/eval-model/\"\n", - "python -m t5x.train_unfragmented \\\n", - " --gin_file=${GIN_FILE_PATH} \\\n", - " --gin.EVAL_OUTPUT_DIR=\\\"${EVAL_OUTPUT_DIR}\\\" \\\n", - " --alsologtostderr\n", - "```\n", - "\n", - "For more details on evaluation using the T5X evaluation binaries, please see the [Evaluation](https://github.com/google-research/t5x/blob/main/docs/usage.md/eval) tutorial." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wi29fMdv4mSr" - }, - "source": [ - "## SeqIO\n", - "\n", - "If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.\n", - "\n", - "We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md/index) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.\n", - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google-research/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n", - "\n", - "If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the code snippets above. We've provided an example of this bridge below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DLSwblIQ7ZCC" - }, - "outputs": [], - "source": [ - "!git clone https://github.com/google-research/google-research.git google_research" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bM0nRIEFwyj_" - }, - "outputs": [], - "source": [ - "import google_research.t5_closed_book_qa.t5_cbqa.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='natural_questions_open',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=2,\n", - " seed=42)\n", - "print(f\"Batches: {batches}\")\n", - "# Train the interactive model on the provided batches.\n", - "original_step = interactive_model.step\n", - "_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Elt08160w03X" - }, - "source": [ - "The `get_batches_from_seqio` bridge can take several constructor arguments:\n", - "\n", - "\n", - "1. `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).\n", - "2. `split`: the split of the Task/Mixture to read data from.\n", - "3. `batch_size`: how many examples should appear in each batch.\n", - "4. `num_batches`: the total number of batches to return.\n", - "5. `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \\\n", - "The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \\\n", - "However, some SeqIO tasks do not use tokenization (ex: span corruption). You can set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` and provide an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.\n", - "6. `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.\n", - "7. `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in the `SeqIO.get_dataset` function. If you would like to set any of these arguments, you can set them using this `kwargs` parameter.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "fjKBCX39w0Xl" - }, - "outputs": [], - "source": [ - "import t5.data.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='c4_v220_span_corruption',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=1,\n", - " get_pretokenized_examples=False,\n", - " sequence_length=interactive_model._task_feature_lengths,\n", - " seed=42)\n", - "batch = batches[0] # We expect only a single batch.\n", - "original_step = interactive_model.step\n", - "interactive_model.train_step_with_preprocessors(\n", - " examples=batch, preprocessors=[])\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [ - "bqZYp90PIa1t", - "lcDwmp_AxnOG" - ], - "last_runtime": { - "build_target": "//learning/grp/tools/ml_python:ml_notebook", - "kind": "private" - }, - "name": "Welcome to T5X: Evaluation Deep Dive", - "private_outputs": true, - "provenance": [ - { - "file_id": "18IRHbzIplnXwxF2ii10vFsqyRhPKBcWA", - "timestamp": 1676344856728 - }, - { - "file_id": "1hQO9MD6psZtTeqZyXPJIoUV0uzTa2qPg", - "timestamp": 1662951508591 - }, - { - "file_id": "1Akpc6pKlJB5rn5YYYFC9lw2OMk6oBzlQ", - "timestamp": 1662754223629 - }, - { - "file_id": "1rA8bgO2bJRoebAuS96Ji0RUhnawgBY4i", - "timestamp": 1650477076639 - } - ] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/t5x-main/t5x/notebooks/inference.ipynb b/t5x-main/t5x/notebooks/inference.ipynb deleted file mode 100644 index 7706ab958d8b5e68dae68fe9319567cab739f8f7..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/notebooks/inference.ipynb +++ /dev/null @@ -1,778 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "-0BQWhvAP2jb" - }, - "source": [ - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bqZYp90PIa1t" - }, - "source": [ - "# Overview\n", - "\n", - "This is the third Colab in a [series of tutorials on how to use T5X](https://github.com/google-research/t5x/blob/main/docs/tutorials.md). We assume that you have already completed the [Introductory Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) and the [Training Deep Dive](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb), or have a basic understanding of the T5X models, checkpoints, partitioner, trainer, and `InteractiveModel`.\n", - "\n", - "In the [previous Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series, we dove into how the InteractiveModel restores models from checkpoints and runs training, while also getting an introduction to the T5X trainer. In this Colab, we will focus on how the `InteractiveModel` does decoding to generate predictions and scores for a given input. It should be noted that the code snippets below exactly replicate the InteractiveModel `__init__()` and `infer_with_preprocessors()` methods (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); we expose this functionality here in order to demonstrate how various components of the T5X codebase work together to run inference on a model." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nZJbWZcfkyxI" - }, - "source": [ - "# Set-Up" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8my9yhSRi6GG" - }, - "source": [ - "Note: If you are a using public colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jIGSIHzD7YPO" - }, - "outputs": [], - "source": [ - "from collections.abc import Sequence\n", - "import enum\n", - "import functools\n", - "import inspect\n", - "import itertools\n", - "import logging\n", - "import os\n", - "import re\n", - "from typing import Any, Callable, Iterator, Optional, Tuple, Union\n", - "\n", - "import jax\n", - "from jax import random\n", - "from jax.experimental import multihost_utils\n", - "import numpy as np\n", - "import seqio\n", - "import tensorflow as tf\n", - "import tensorflow_datasets as tfds\n", - "import t5.data\n", - "import t5.data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yNtayQIxjEBd" - }, - "outputs": [], - "source": [ - "import clu.data\n", - "from t5x.examples.t5 import network\n", - "import t5x\n", - "from t5x import models\n", - "from t5x import partitioning\n", - "from t5x import trainer as trainer_lib\n", - "from t5x import utils\n", - "from t5x.infer import _extract_tokens_and_aux_values\n", - "from t5x.infer import _Inferences\n", - "from t5x.interactive_model import InteractiveModel\n", - "from t5x.interactive_model import get_batches_from_seqio\n", - "from t5x.interactive_model import get_dataset_from_natural_text_examples\n", - "from t5x.interactive_model import get_gin_config_from_interactive_model\n", - "from t5x.interactive_model import T5XScriptType\n", - "from t5x.interactive_model import InferenceType" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S5Lb-Z1fkF5a" - }, - "source": [ - "Before we begin, let's initialize instances of the constructor arguments for the `InteractiveModel`. As mentioned previously, this will enable us to dive into how the `InteractiveModel` runs inference.\n", - "\n", - "If you don't understand the lines of code below, or have questions about how to initialize these parameters, please see the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ne8U8qoWkX_r" - }, - "outputs": [], - "source": [ - "# Define a model. The configuration below corresponds to the T5 1.1 Small model.\n", - "t5_config = network.T5Config(\n", - " vocab_size=32128,\n", - " dtype='bfloat16',\n", - " emb_dim=512,\n", - " num_heads=6,\n", - " num_encoder_layers=8,\n", - " num_decoder_layers=8,\n", - " head_dim=64,\n", - " mlp_dim=1024,\n", - " mlp_activations=('gelu', 'linear'),\n", - " dropout_rate=0.0,\n", - " logits_via_embedding=False)\n", - "module = network.Transformer(config=t5_config)\n", - "model = t5x.models.EncoderDecoderModel(\n", - " module=module,\n", - " input_vocabulary=t5.data.get_default_vocabulary(),\n", - " output_vocabulary=t5.data.get_default_vocabulary(),\n", - " optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))\n", - "# Define checkpoint arguments.\n", - "checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'\n", - "dtype='bfloat16'\n", - "restore_mode='specific'\n", - "# Define a partitioner.\n", - "partitioner=partitioning.PjitPartitioner(num_partitions=2)\n", - "# Define additional, miscellaneous constructor arguments.\n", - "batch_size=8\n", - "task_feature_lengths = {'inputs': 38, 'targets': 18}\n", - "output_dir='/tmp/output_dir'\n", - "input_shapes = {\n", - " 'encoder_input_tokens': np.array([8, 38]),\n", - " 'decoder_target_tokens': np.array([8, 18]),\n", - " 'decoder_input_tokens': np.array([8, 18]),\n", - " 'decoder_loss_weights': np.array([8, 18])\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EYwdg-fFTU8Q" - }, - "source": [ - "In addition, we will run all code that is performed when we initialize the InteractiveModel. If you don't understand the lines of code below or have any additional questions about how/why we do the steps below, please see the [second Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YmGTJBAcTpMR" - }, - "outputs": [], - "source": [ - "# 1.) Configure the Output Directory\n", - "output_dir = re.sub(r\"(?\u003c!gs:)([\\/]{2,})\", \"/\", output_dir)\n", - "if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - "\n", - "# 2.) Initialize RNGs\n", - "init_random_seed = 42\n", - "random_seed = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed))\n", - "utils.set_hardware_rng_ops()\n", - "rng = random.PRNGKey(random_seed)\n", - "init_rng, trainer_rng = random.split(rng, 2)\n", - "\n", - "# 3.) Validate the Partitioner\n", - "if partitioner._model_parallel_submesh:\n", - " num_partitions = np.prod(partitioner._model_parallel_submesh)\n", - "else:\n", - " num_partitions = partitioner._num_partitions\n", - "if jax.device_count() % num_partitions != 0:\n", - " raise ValueError(\n", - " \"The number of devices available must be a multiple of the number of\",\n", - " f\" partitions. There are {jax.device_count()} devices available, but\",\n", - " f\" the number of partitions is set to {num_partitions}. Please\",\n", - " \" provide a different number of partitions.\")\n", - "\n", - "# 4.) Create a Checkpoint Manager\n", - "# a.) Define CheckpointCfg wrappers.\n", - "save_checkpoint_cfg = utils.SaveCheckpointConfig(\n", - " dtype=dtype,\n", - " keep=5, # The number of checkpoints to keep in the output_dir.\n", - " save_dataset=False)\n", - "restore_checkpoint_cfg = utils.RestoreCheckpointConfig(\n", - " dtype=dtype,\n", - " mode=restore_mode,\n", - " path=checkpoint_path)\n", - "\n", - "# b.) Define a train state initializer, which will help us get information about the\n", - "# TrainState shape.\n", - "train_state_initializer = utils.TrainStateInitializer(\n", - " optimizer_def=model.optimizer_def,\n", - " init_fn=model.get_initial_variables,\n", - " input_shapes=input_shapes,\n", - " input_types=None,\n", - " partitioner=partitioner)\n", - "\n", - "# c.) Define the checkpoint manager.\n", - "checkpoint_manager = utils.LegacyCheckpointManager(\n", - " save_cfg=save_checkpoint_cfg,\n", - " restore_cfg=restore_checkpoint_cfg,\n", - " train_state_shape=train_state_initializer.global_train_state_shape,\n", - " partitioner=partitioner,\n", - " ds_iter=None,\n", - " model_dir=output_dir)\n", - "\n", - "### 5.) Restore the Model from a Checkpoint, or Initialize from Scratch ###\n", - "def get_state(rng):\n", - " return train_state_initializer.from_scratch(rng).state_dict()\n", - "\n", - "# a.) Try to restore a model from a checkpoint.\n", - "train_state = checkpoint_manager.restore(\n", - " [restore_checkpoint_cfg.path],\n", - " restore_checkpoint_cfg,\n", - " utils.get_fallback_state(restore_checkpoint_cfg, get_state, init_rng)\n", - ")\n", - "\n", - "# b.) If no checkpoint to restore, init from scratch.\n", - "if train_state is None:\n", - " train_state = train_state_initializer.from_scratch(init_rng)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ib9aOi2xaCKQ" - }, - "source": [ - "# Inference Deep Dive" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ANqpfv0lAVqL" - }, - "source": [ - "**Defining a Batch of Examples to Run Inference On**\\\n", - "Let's start by defining a batch of examples that we will get predictions and scores for.\n", - "\n", - "These examples should be a list of inputs; we don't need any targets, because we will eventually generate predictions. For this Colab, we'll use a set of natural text questions (and we will generate the answers)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yhhR0yDcAn7w" - }, - "outputs": [], - "source": [ - "examples = [\n", - " b'nq question: who has been appointed as the new chairman of sebi',\n", - " b'nq question: who wrote the book lion the witch and the wardrobe',\n", - " b'nq question: how many planes did japan lose at pearl harbor',\n", - " b'nq question: who does the voice of mcgruff the dog',\n", - " b'nq question: who sings the wheels in the sky keep on turning',\n", - " b'nq question: who voices regina in glitter force doki doki',\n", - " b'nq question: when did the us become allies with britain',\n", - " b'nq question: who won the rugby 7 in las vegas'\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WYV1LMS5taE9" - }, - "source": [ - "We also define the required features of the examples. For this Colab, we will only require an `inputs` and `targets` entry, as defined below. `targets` will be empty for our examples, because we do not have any targets to provide at inference time." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nj5I7YMotb9U" - }, - "outputs": [], - "source": [ - "output_features = {\n", - " \"inputs\":\n", - " seqio.Feature(\n", - " vocabulary=model.input_vocabulary, add_eos=True),\n", - " \"targets\":\n", - " seqio.Feature(\n", - " vocabulary=model.output_vocabulary, add_eos=True)\n", - " }\n", - "features = dict(sorted(output_features.items()))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IZuleutswi1H" - }, - "source": [ - "Finally, we'll have to determine whether we want to get predictions or scores for this batch. For this example, we'll get predictions, which we'll denote by setting an inference mode variable to `PREDICT_WITH_AUX`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PbMPt5eBw-a4" - }, - "outputs": [], - "source": [ - "mode = InferenceType.PREDICT_WITH_AUX\n", - "# Try replacing this variable with `InferenceType.SCORE` to produce scores." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mixLzcBkQOT_" - }, - "source": [ - "Now, let's break down what the interactive model does to run inference.\n", - "\n", - "The `InteractiveModel` `infer_with_preprocessors()` method only performs three actions:\n", - "\n", - "\n", - "1. Convert the natural text examples into a tf.Dataset.\n", - "2. Define an `infer_fn`; depending on whether we want predictions or scores, this function will be equivalent to `model.predict_batch` or `model.score_batch`.\n", - "3. Extract inferences and return them.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ug0zJx2kQk6g" - }, - "source": [ - "**Prepare the dataset** \\\n", - "\n", - "Preparing the data for inference is fairly straightforward; in fact, this is nearly the same data preparation that happens for training.\n", - "\n", - "First, we convert the natural text examples into a tf.Dataset and run any preprocessors; T5X has a helper function, `get_dataset_from_natural_text_examples`, that can do exactly that. For this example, the only preprocessing we will do is tokenization and appending an EOS token. If you are interested in learning more about preprocessors, please take a look at https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-intro.\n", - "\n", - "Finally, we convert all features using the model's feature converter, pad all batches of data, and define an iterator over our data (this allows us to run inference on multiple batches of examples)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "chPomDFxQ6r3" - }, - "outputs": [], - "source": [ - "dataset = get_dataset_from_natural_text_examples(\n", - " examples,\n", - " preprocessors=[\n", - " seqio.preprocessors.tokenize,\n", - " seqio.preprocessors.append_eos\n", - " ],\n", - " task_feature_lengths=task_feature_lengths,\n", - " features=features)\n", - "feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)\n", - "model_dataset = feature_converter(\n", - " dataset, task_feature_lengths=task_feature_lengths)\n", - "# Zip task and model features.\n", - "infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))\n", - "# Create batches and index them.\n", - "infer_dataset = infer_dataset.padded_batch(\n", - " batch_size, drop_remainder=False).enumerate()\n", - "infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(\n", - " infer_dataset.prefetch(tf.data.experimental.AUTOTUNE))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "toq2uz7dAfdL" - }, - "source": [ - "**Define Infer Function** \\\n", - "\n", - "We'll define a helper function that runs inference on a single batch, making it easy to loop over this helper and run inference for multiple batches. This `infer_fn` can either get predictions or scores, depending on the mode we've previously set.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Qv8RXHeXK6mk" - }, - "outputs": [], - "source": [ - "if mode == InferenceType.PREDICT_WITH_AUX:\n", - " infer_step = model.predict_batch_with_aux\n", - "elif mode == InferenceType.SCORE:\n", - " infer_step = model.score_batch\n", - "else:\n", - " raise ValueError(\"Mode must be `predict_with_aux`, or `score`,\"\n", - " f\" but instead was {mode}.\")\n", - "infer_fn = functools.partial(\n", - " utils.get_infer_fn(\n", - " infer_step=infer_step,\n", - " batch_size=batch_size,\n", - " train_state_axes=train_state_initializer.train_state_axes,\n", - " partitioner=partitioner),\n", - " train_state=train_state)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c-BoczLLzaGb" - }, - "source": [ - "**Extract Inferences** \\\n", - "\n", - "Finally, we will extract inferences for each batch of examples provided. For each batch, we:\n", - "\n", - "1. Unzip the dataset to get both the task dataset and the model dataset (the model dataset is what you get when you've passed the task dataset through the model feature converter).\n", - "2. Get an RNG for the batch.\n", - "3. Extract predictions and auxiliary values using the T5X helper, `_extract_tokens_and_aux_values`.\n", - "4. Decode the predictions using our vocabulary.\n", - "5. Accumulate predictions, aux values, and inputs across all of our batches.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GDAyEz9mzghq" - }, - "outputs": [], - "source": [ - "# Main Loop over \"batches\".\n", - "all_inferences = []\n", - "all_aux_values = {}\n", - "for chunk, chunk_batch in infer_dataset_iter:\n", - " # Load the dataset for the next chunk. We can't use `infer_dataset_iter`\n", - " # directly since `infer_fn` needs to know the exact size of each chunk,\n", - " # which may be smaller for the final one.\n", - " chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)\n", - " chunk_dataset.cache().prefetch(tf.data.experimental.AUTOTUNE)\n", - "\n", - " # Unzip chunk dataset in to pretokenized and model datasets.\n", - " task_dataset = chunk_dataset.map(\n", - " lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", - " model_dataset = chunk_dataset.map(\n", - " lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", - "\n", - " # Get a chunk-specific RNG key.\n", - " chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)\n", - "\n", - " inferences = _extract_tokens_and_aux_values(\n", - " infer_fn(model_dataset.enumerate(), rng=chunk_rng))\n", - "\n", - " predictions, aux_values = inferences\n", - " accumulated_inferences = []\n", - " for idx, inputs in task_dataset.enumerate().as_numpy_iterator():\n", - " prediction = predictions[idx]\n", - " # Decode predictions if applicable.\n", - " if mode == InferenceType.PREDICT_WITH_AUX:\n", - " prediction = features[\"targets\"].vocabulary.decode_tf(\n", - " tf.constant(prediction)).numpy()\n", - " accumulated_inferences.append((inputs, prediction))\n", - " all_inferences += accumulated_inferences\n", - " # Accumulate aux values over batches.\n", - " if not all_aux_values:\n", - " all_aux_values = aux_values\n", - " else:\n", - " for key, values in aux_values.items():\n", - " all_aux_values[key] += values\n", - "print(all_inferences)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rIzrCUyWQcDZ" - }, - "source": [ - "We can parse these predictions into a more readable format using the code below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f9_BPXG_QgJs" - }, - "outputs": [], - "source": [ - "for input, prediction in all_inferences:\n", - " print(f\"Input: {input['inputs_pretokenized']}\")\n", - " print(f\"Prediction: {prediction}\\n\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AB0U_kfRNIyR" - }, - "source": [ - "The code snippets above exactly replicate the `InteractiveModel` `infer_with_preprocessors()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running `interactive_model.infer_with_preprocessors(mode, examples, preprocessors=[seqio.preprocessors.tokenize, seqio.preprocessors.append_eos])`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lcDwmp_AxnOG" - }, - "source": [ - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-QR5LnmN4ikp" - }, - "source": [ - "# Advanced Topics" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CLstCKpP8Ge7" - }, - "source": [ - "## T5X Inference Binaries and Other Advanced Features\n", - "\n", - "T5X offers inference binaries that have the same functionality as the InteractiveModel, with additional features as well (more advanced compiling, inference on TF Example files, prediction services, etc.). Importantly, these binaries are configured using [Gin](https://github.com/google/gin-config/blob/main/README.md); if you are not familiar with Gin, please take a look at this [Gin Primer](https://github.com/google-research/t5x/blob/main/docs/usage.md/gin) to get started.\n", - "\n", - "If you are familiar with Gin and interested in using the T5X inference binaries, we have provided a helper function, get_gin_config_from_interactive_model, which will take an InteractiveModel instance and generate the gin config that you can use to run the T5X inference binaries; this gin config will exactly reproduce the InteractiveModel inference functionality we've described above. We've provided an example below.\n", - "\n", - "Importantly, the InteractiveModel takes in a model, partitioner, and data, so we cannot generate Gin configs for these components. You can pass Gin config strings for the model and partitioner components to the helper function, as demonstrated below. Additionally, you can pass a SeqIO task containing your data to the helper function. See the section below if you are unfamiliar with SeqIO." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nqa49ZpjnRN1" - }, - "outputs": [], - "source": [ - "!git clone https://github.com/google-research/google-research.git google_research" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rhgUZ0w6yQsE" - }, - "outputs": [], - "source": [ - "# Define an InteractiveModel instance, based on the `small` T5X EncoderDecoder model.\n", - "t5_config = network.T5Config(\n", - " vocab_size=32128,\n", - " dtype='bfloat16',\n", - " emb_dim=512,\n", - " num_heads=6,\n", - " num_encoder_layers=8,\n", - " num_decoder_layers=8,\n", - " head_dim=64,\n", - " mlp_dim=1024,\n", - " mlp_activations=('gelu', 'linear'),\n", - " dropout_rate=0.0,\n", - " logits_via_embedding=False)\n", - "module = network.Transformer(config=t5_config)\n", - "model = t5x.models.EncoderDecoderModel(\n", - " module=module,\n", - " input_vocabulary=t5.data.get_default_vocabulary(),\n", - " output_vocabulary=t5.data.get_default_vocabulary(),\n", - " optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),\n", - " decode_fn=functools.partial(\n", - " t5x.decoding.temperature_sample, temperature=1.0, topk=40))\n", - "interactive_model = InteractiveModel(\n", - " batch_size=8,\n", - " task_feature_lengths={'inputs': 38, 'targets': 18},\n", - " output_dir='/tmp/output_dir',\n", - " partitioner=partitioning.PjitPartitioner(\n", - " num_partitions=1,\n", - " model_parallel_submesh=None,\n", - " logical_axis_rules=partitioning.standard_logical_axis_rules()),\n", - " model=model,\n", - " dtype='bfloat16',\n", - " restore_mode='specific',\n", - " checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000',\n", - " input_shapes={\n", - " 'encoder_input_tokens': np.array([8, 38]),\n", - " 'decoder_target_tokens': np.array([8, 18]),\n", - " 'decoder_input_tokens': np.array([8, 18]),\n", - " 'decoder_loss_weights': np.array([8, 18])\n", - " },\n", - " input_types=None)\n", - "\n", - "# Define Gin Config strings for the model, partitioner, and any imports.\n", - "imports_str = \"\"\"from t5x import models\n", - "from t5x import partitioning\n", - "import t5.data.mixtures\n", - "include 't5x/examples/t5/t5_1_1/small.gin'\n", - "\n", - "# Register necessary SeqIO Tasks/Mixtures.\n", - "import google_research.t5_closed_book_qa.t5_cbqa.tasks\"\"\"\n", - "partitioner_config = 'partitioning.PjitPartitioner.num_partitions = 2'\n", - "\n", - "gin_config_str = get_gin_config_from_interactive_model(\n", - " interactive_model=interactive_model,\n", - " script_type=T5XScriptType.INFERENCE,\n", - " task_name='closed_book_qa',\n", - " partitioner_config_str=partitioner_config,\n", - " model_config_str='', # No config needed, since we just import the model.\n", - " imports_str=imports_str,\n", - ")\n", - "print(gin_config_str)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uGd1DxDT3gB7" - }, - "source": [ - "Once you have generated the `gin_config_str` as above, you can write this string to a file and launch your inference experiment locally by running the following on commandline:\n", - "\n", - "\n", - "```\n", - "INFER_OUTPUT_DIR=\"/tmp/inference-model/\"\n", - "python -m t5x.infer_unfragmented \\\n", - " --gin_file=${GIN_FILE_PATH} \\\n", - " --gin.INFER_OUTPUT_DIR=\\\"${INFER_OUTPUT_DIR}\\\" \\\n", - " --alsologtostderr\n", - "```\n", - "For more details on inference using the T5X inference binaries, please see the [Inference](https://github.com/google-research/t5x/blob/main/docs/usage.md/infer-seqio) tutorial." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wi29fMdv4mSr" - }, - "source": [ - "## SeqIO\n", - "\n", - "If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.\n", - "\n", - "We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md/index) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.\n", - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google-research/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n", - "\n", - "If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the code snippets above. We've provided an example of this bridge below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bM0nRIEFwyj_" - }, - "outputs": [], - "source": [ - "import google_research.t5_closed_book_qa.t5_cbqa.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='natural_questions_open',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=2,\n", - " seed=42)\n", - "print(f\"Batches: {batches}\")\n", - "# Train the interactive model on the provided batches.\n", - "original_step = interactive_model.step\n", - "_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Elt08160w03X" - }, - "source": [ - "The `get_batches_from_seqio` bridge can take several constructor arguments:\n", - "\n", - "\n", - "1. `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).\n", - "2. `split`: the split of the Task/Mixture to read data from.\n", - "3. `batch_size`: how many examples should appear in each batch.\n", - "4. `num_batches`: the total number of batches to return.\n", - "5. `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \\\n", - "The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \\\n", - "However, some SeqIO tasks do not use tokenization (ex: span corruption). You can set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` and provide an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.\n", - "6. `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.\n", - "7. `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in the `SeqIO.get_dataset` function. If you would like to set any of these arguments, you can set them using this `kwargs` parameter.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "fjKBCX39w0Xl" - }, - "outputs": [], - "source": [ - "import t5.data.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='c4_v220_span_corruption',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=1,\n", - " get_pretokenized_examples=False,\n", - " sequence_length=interactive_model._task_feature_lengths,\n", - " seed=42)\n", - "batch = batches[0] # We expect only a single batch.\n", - "original_step = interactive_model.step\n", - "interactive_model.train_step_with_preprocessors(\n", - " examples=batch, preprocessors=[])\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - } - ], - "metadata": { - "colab": { - "last_runtime": { - "build_target": "//learning/grp/tools/ml_python:ml_notebook", - "kind": "private" - }, - "name": "Welcome to T5X: Inference Deep Dive", - "private_outputs": true, - "provenance": [ - { - "file_id": "1ZyKqEf1xpxGUxX9JeQ0VN94MQlWGFrHf", - "timestamp": 1676340117712 - }, - { - "file_id": "1hQO9MD6psZtTeqZyXPJIoUV0uzTa2qPg", - "timestamp": 1662951508591 - }, - { - "file_id": "1Akpc6pKlJB5rn5YYYFC9lw2OMk6oBzlQ", - "timestamp": 1662754223629 - }, - { - "file_id": "1rA8bgO2bJRoebAuS96Ji0RUhnawgBY4i", - "timestamp": 1650477076639 - } - ] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/t5x-main/t5x/notebooks/introduction.ipynb b/t5x-main/t5x/notebooks/introduction.ipynb deleted file mode 100644 index 40f9a4c235638737119eeb940d13973b8402c1bc..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/notebooks/introduction.ipynb +++ /dev/null @@ -1,746 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "-0BQWhvAP2jb" - }, - "source": [ - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bqZYp90PIa1t" - }, - "source": [ - "# Overview\n", - "\n", - "T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.\n", - "\n", - "It is essentially a new and improved implementation of the [T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md) (based on Mesh TensorFlow) in JAX and Flax.\n", - "\n", - "# Getting Started\n", - "\n", - "In the following Colab, we present an introductory tutorial to get you started interacting with the T5X codebase. In particular, we'll introduce the major components of the T5X codebase and get you started running training, inference, and evaluation on natural text inputs.\n", - "\n", - "Note: If you are a using public colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jIGSIHzD7YPO" - }, - "outputs": [], - "source": [ - "import functools\n", - "import os\n", - "\n", - "import clu.data.dataset_iterator\n", - "import tensorflow as tf\n", - "import jax\n", - "from jax import random\n", - "from jax.experimental import multihost_utils\n", - "import jax.numpy as jnp\n", - "from flax import linen\n", - "import numpy as np\n", - "import seqio\n", - "import t5.data\n", - "from t5.evaluation import metrics as t5_metrics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "q2NtKT-vlomE" - }, - "outputs": [], - "source": [ - "import t5x\n", - "from t5x import partitioning\n", - "from t5x import train_state as train_state_lib\n", - "from t5x import utils\n", - "from t5x.examples.t5 import network\n", - "from t5x.examples.scalable_t5 import network as scalable_network\n", - "from t5x.interactive_model import InteractiveModel\n", - "from t5x.interactive_model import get_batches_from_seqio\n", - "from t5x.interactive_model import InferenceType" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ib9aOi2xaCKQ" - }, - "source": [ - "# T5X Components" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-trKAuOWaGoB" - }, - "source": [ - "Let's start by going over some of the major components of the T5X codebase: models, checkpoints, and partitioners.\n", - "\n", - "We will define instances of some of these components in the following subsections before we use them to run training, inference, and evaluation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RE_CQr9Hcr1D" - }, - "source": [ - "## T5X Models\n", - "One of the primary contributions of the T5X codebase is its easy-to-use collection of models.\n", - "\n", - "The T5X codebase provides an abstract base class, [`BaseModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py), which should be subclassed to define specific model architectures. This abstraction allows us to flexibly extend the T5X framework to custom architectures. Importantly, the `BaseModel` and all subclasses are free from parallelism-related features (this is handled by the partitioner; see following sections).\n", - "\n", - "The T5X codebase also provides several widely-used subclasses of the `BaseModel`, namely the [`EncoderDecoderModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py) and the [`DecoderModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py).\n", - "\n", - "Importantly, the proposed structure of the `BaseModel`/all subclasses does not impose that the model be implemented in a specific framework. Instead, all subclasses of the `BaseModel` take in an `nn.Module` constructor argument, which is used to implement the architecture of the model. These modules can be built in Flax (e.g. [minimal T5](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/network.py)) or on top of a layers library such as [Flaxformer](https://github.com/google/flaxformer).\n", - "\n", - "We've provided a sample model definition below. For this example, we will instantiate an `EncoderDecoderModel`, which will also require us to define input and output vocabularies, an optimizer, and a decode function. We'll use the minimal T5 module to implement our model architecture." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FqH_jEcUdEsd" - }, - "outputs": [], - "source": [ - "# Define EncoderDecoderModel constructor args (except the module).\n", - "input_vocabulary=t5.data.get_default_vocabulary()\n", - "output_vocabulary=t5.data.get_default_vocabulary()\n", - "optimizer=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0, logical_factor_rules=t5x.adafactor.standard_logical_factor_rules())\n", - "decode_fn=functools.partial(t5x.decoding.temperature_sample, temperature=1.0, topk=40)\n", - "\n", - "# Define a model using the minimal T5 module.\n", - "t5_module = network.Transformer(config=network.T5Config(\n", - " vocab_size=32128,\n", - " dtype='bfloat16',\n", - " emb_dim=512,\n", - " num_heads=6,\n", - " num_encoder_layers=8,\n", - " num_decoder_layers=8,\n", - " head_dim=64,\n", - " mlp_dim=1024,\n", - " mlp_activations=('gelu', 'linear'),\n", - " dropout_rate=0.0,\n", - " logits_via_embedding=False))\n", - "model = t5x.models.EncoderDecoderModel(\n", - " module=t5_module,\n", - " input_vocabulary=input_vocabulary,\n", - " output_vocabulary=output_vocabulary,\n", - " optimizer_def=optimizer,\n", - " decode_fn=decode_fn)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "r74x8nJpfe3G" - }, - "source": [ - "## Checkpoints\n", - "\n", - "The T5X codebase also includes checkpoints for a wide variety of pre-trained T5X models. A full list of all publically available checkpoints is available at https://github.com/google-research/t5x/blob/main/docs/models.md.\n", - "\n", - "For the following example, we have selected a pretrained [T5 1.1 Small model](https://github.com/google-research/t5x/blob/main/docs/models.md) that has been additionally finetuned to answer natural questions using the (open domain) [Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions). We use this finetuned checkpoint for this example in order to see improved performance on the natural question examples we will use for training/inference/evaluation later on.\n", - "\n", - "To restore our model from this checkpoint, we first define the path to our checkpoint and the `dtype` to restore." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hvVNR5AWgPOC" - }, - "outputs": [], - "source": [ - "# The checkpoint below is a T5-1.1-Small checkpoint (https://github.com/google-research/t5x/blob/main/docs/models.md)\n", - "# that has additionally been finetuned on the (Open Domain) Natural Questions\n", - "# benchmark (https://ai.google.com/research/NaturalQuestions).\n", - "checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'\n", - "dtype='bfloat16'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Obgf5nfMgdkm" - }, - "source": [ - "We also need to define how we want to restore our model. There are two different restore modes that are available in T5X; for now, we will use \"specific\", which will load the most recent checkpoint in the directory specified by `checkpoint_path`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "eqfWb6GZhVz-" - }, - "outputs": [], - "source": [ - "restore_mode='specific'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PkG5uRfdhnzR" - }, - "source": [ - "Finally, it should be noted that if you are restoring your model from a checkpoint, then the model architecture you defined above must match the model architecture of your checkpoint. For all T5X checkpoints listed at https://github.com/google-research/t5x/blob/main/docs/models.md, you can find the correct architecture for the given checkpoint in its corresponding Gin file." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a4wzrzzTiRbl" - }, - "source": [ - "## Partitioners\n", - "\n", - "Partitioning is the process of dividing and replicating machine learning model parameters, activations, and data across accelerator devices in order to:\n", - "\n", - "\n", - "* Train and infer from models too large to fit in the memory of a single device\n", - "* Use extremely large batch sizes\n", - "* Train faster\n", - "\n", - "In T5X, partitioning is primarily provided through the [jax.pjit](https://github.com/google/jax/tree/main/jax/experimental/pjit.py) fronted via `PjitPartitioner`. `PjitPartitioner` has three primary constructor arguments:\n", - "* `model_parallel_submesh`\n", - "* `num_partitions`\n", - "* `logical_axis_rules`\n", - "\n", - "The `model_parallel_submesh` and `num_partitions` arugments provide two mutually exclusive methods of specifying the submesh of devices to use for model partitioning. If you specify `num_partitions`, T5X will use this value to generate a default `model_parallel_submesh` that is suitable, but may not be the optimal configuration. If you are interested in optimizing performance, you can try out different submeshes using the `model_parallel_submesh` parameter. For simplicity, we will use `num_partitions` in this Colab.\n", - "\n", - "If you are interested in learning more about partitioning, please take a look at our T5X: Partitioning Deep Dive Colab (Colab status: WIP, link is upcoming).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kWEmFrvRkAs6" - }, - "outputs": [], - "source": [ - "partitioner=partitioning.PjitPartitioner(\n", - " num_partitions=1,\n", - " model_parallel_submesh=None)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8b1ECSqVlCcl" - }, - "source": [ - "# Running Training, Inference, and Evaluation\n", - "\n", - "Now, let's get started running training, inference, and evaluation on natural text inputs. T5X provides an `InteractiveModel` class that we can wrap around our model, checkpoint, and partitioner components, enabling us to run training, inference, and evaluation in one line of code each.\n", - "\n", - "The InteractiveModel requires a couple of additional constructor arguments, namely:\n", - "\n", - "\n", - "1. `batch_size`: the number of examples per batch for training, inference, and evaluation.\n", - "2. `task_feature_lengths`: `task_feature_lengths` is a dictionary mapping the task feature key to the maximum length (int) for that feature. If a feature is longer than this length after preprocessing, the feature will be truncated. May be set to `None` to avoid truncation. \\\n", - "For context, task features are specific to tasks (ex: inputs and targets), and can be mapped to various model-specific features (for example, if we are using a decoder-only model, the concatenation of inputs and targets will be mapped to `decoder_target_tokens`, the model features). This mapping is done by the model's feature converter.\n", - "3. `output_dir`: Path to directory where we will write new model checkpoints.\n", - "4. `input_shapes`: a mapping from key to array shape for each model feature in the global (unsharded) input batch. These input shapes are used to define and initialize the train state. Importantly, these input shapes define the *model features* shape, in contrast to the task features described above.\n", - "\n", - "We define these arguments and an instance of the InteractiveModel below. Importantly, it should be noted that the InteractiveModel handles restoring our model from the provided checkpoint path, so once we instantiate the InteractiveModel, we will be ready to run training, inference, and evaluation. Restoring the model from a checkpoint may take a minute or two.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ARvXFgR6l8T5" - }, - "outputs": [], - "source": [ - "batch_size=8\n", - "task_feature_lengths = {'inputs': 38, 'targets': 18}\n", - "output_dir='/tmp/output_dir'\n", - "input_shapes = {\n", - " 'encoder_input_tokens': np.array([8, 38]),\n", - " 'decoder_target_tokens': np.array([8, 18]),\n", - " 'decoder_input_tokens': np.array([8, 18]),\n", - " 'decoder_loss_weights': np.array([8, 18])\n", - "}\n", - "\n", - "interactive_model = InteractiveModel(\n", - " batch_size=batch_size,\n", - " task_feature_lengths=task_feature_lengths,\n", - " output_dir=output_dir,\n", - " partitioner=partitioner,\n", - " model=model,\n", - " dtype=dtype,\n", - " restore_mode=restore_mode,\n", - " checkpoint_path=checkpoint_path,\n", - " input_shapes=input_shapes\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xOCiFtbyl2wc" - }, - "source": [ - "Next, let's define some examples that we want to use for training/inference/evaluation. These examples should either be a list of inputs, or a list of dictionaries mapping 'target'/'input' keys to corresponding values, as shown below. We will define two sets of examples: one set to be trained on, and one set to run inference/evaluation on.\n", - "\n", - "We are using natural question/answer pairs for our examples. As described in the [T5 paper](https://arxiv.org/abs/1910.10683), we must add a task-specific prefix to our input before we feed it to the model in order to specify what task we should perform on the provided input. For natural questions, we use the \"nq question:\" prefix." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KouYfpKflw04" - }, - "outputs": [], - "source": [ - "training_examples = [\n", - " {\n", - " 'input': 'nq question: who has been appointed as the new chairman of sebi',\n", - " 'target': 'Ajay Tyagi'\n", - " },\n", - " {\n", - " 'input': 'nq question: who wrote the book lion the witch and the wardrobe',\n", - " 'target': 'C. S. Lewis'\n", - " },\n", - " {\n", - " 'input': 'nq question: how many planes did japan lose at pearl harbor',\n", - " 'target': '29'\n", - " },\n", - " {\n", - " 'input': 'nq question: who does the voice of mcgruff the dog',\n", - " 'target': 'Jack Keil'\n", - " },\n", - " {\n", - " 'input': 'nq question: who sings the wheels in the sky keep on turning',\n", - " 'target': 'Journey'\n", - " },\n", - " {\n", - " 'input': 'nq question: who voices regina in glitter force doki doki',\n", - " 'target': 'Kumiko Watanabe'\n", - " },\n", - " {\n", - " 'input': 'nq question: when did the us become allies with britain',\n", - " 'target': 'during World War II'\n", - " },\n", - " {\n", - " 'input': 'nq question: who won the rugby 7 in las vegas',\n", - " 'target': 'the United States'\n", - " },\n", - "]\n", - "\n", - "validation_examples = [\n", - " {\n", - " 'target': 'Joe Biden',\n", - " 'input':'nq question: who is the president of the united states'\n", - " },\n", - " {\n", - " 'target': 'F. Scott Fitzgerald',\n", - " 'input': 'nq question: who wrote the book the great gatsby'},\n", - " {\n", - " 'target': '1914',\n", - " 'input': 'nq question: in what year did the first world war begin'},\n", - " {\n", - " 'target': 'Idina Menzel',\n", - " 'input': 'nq question: who does the voice of elsa in Frozen'},\n", - " {\n", - " 'target': 'Taylor Swift',\n", - " 'input': 'nq question: who sings shake it off'},\n", - " {\n", - " 'target': 'Tom Kenny',\n", - " 'input': 'nq question: who voices spongebob squarepants'},\n", - " {\n", - " 'target': '2010',\n", - " 'input': 'nq question: when did the great british bake off start'},\n", - " {\n", - " 'target': 'the Philadelphia Eagles',\n", - " 'input': 'nq question: who won the superbowl in 2018'},\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JAMJOydml0NU" - }, - "source": [ - "Now, we can run training, inference and evaluation on these examples with a single line of code for each task. Below, we run training and inference (evaluation requires a few more arguments, so we go over evaluation in a following section). This may take ~60 seconds." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8OsCbMb0XlC0" - }, - "outputs": [], - "source": [ - "interactive_model.train_step(examples=training_examples)\n", - "print(f\"Training Summary: {interactive_model.train_summary}\\n\")\n", - "print(f\"Step Number: {interactive_model.step}\\n\")\n", - "\n", - "examples_and_predictions, _ = interactive_model.predict_with_aux(examples=validation_examples)\n", - "predictions = [prediction for example, prediction in examples_and_predictions]\n", - "print(f\"Predictions: {predictions}\\n\")\n", - "\n", - "examples_and_scores = interactive_model.score(examples=validation_examples)\n", - "scores = [score for example, score in examples_and_scores]\n", - "print(f\"Scores: {scores}\\n\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0SnCvC4d4pjA" - }, - "source": [ - "Alternately, you can run a training/inference/evaluation loop over multiple batches. The training loop below runs training and inference for each step, using the provided batches, and returns the predictions and scores from the final step. This may take ~60 seconds (note: if you use XL or XXL model sizes, this loop may take a while to complete; we are working on improved compilation strategies that optimize for runtime in b/247170488)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GaBd-e5s4xZp" - }, - "outputs": [], - "source": [ - "second_batch_of_examples = [\n", - " {\n", - " 'input': 'nq question: who won the most academy awards in his lifetime',\n", - " 'target': 'Walt Disney'\n", - " },\n", - " {\n", - " 'input': 'nq question: who starred in the hand that rocks the cradle',\n", - " 'target': 'Rebecca De Mornay'\n", - " },\n", - " {\n", - " 'input': 'nq question: what does a red license plate mean in ontario',\n", - " 'target': 'diplomat'\n", - " },\n", - " {\n", - " 'input': 'nq question: who sang i dreamed a dream on britain\\'s got talent',\n", - " 'target': 'Susan Magdalane Boyle'\n", - " },\n", - " {\n", - " 'input': 'nq question: when is season 7 of game of thrones being released',\n", - " 'target': 'August 27, 2017'\n", - " },\n", - " {\n", - " 'input': 'nq question: when is anne with an e season two coming out',\n", - " 'target': 'in 2018'\n", - " },\n", - " {\n", - " 'input': 'nq question: when was hard rock hotel las vegas built',\n", - " 'target': '1995'\n", - " },\n", - " {\n", - " 'input': 'nq question: what type of reaction leads to the production of polymers',\n", - " 'target': 'condensation reaction'\n", - " }\n", - "]\n", - "all_training_batches = [training_examples, second_batch_of_examples]\n", - "examples_and_predictions, examples_and_scores, _ = interactive_model.train_loop(num_steps=2, train_batches=all_training_batches, predict_batches=[validation_examples], score_batches=[validation_examples])\n", - "\n", - "print(\"\\n All Predictions\")\n", - "for example, prediction in examples_and_predictions:\n", - " print(f\"Input: {example['inputs_pretokenized']}, Prediction: {prediction}\")\n", - "print(\"\\nAll Scores:\")\n", - "for example, score in examples_and_scores:\n", - " print(f\"Input: {example['inputs_pretokenized']}, Score: {score}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iGaGfR84_8Ap" - }, - "source": [ - "### Preprocessors" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lpRZd2F4-G8l" - }, - "source": [ - "By default, the only preprocessors that the methods above run are [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py). If you would like to use different preprocessors, you can do so using the `train_step_with_preprocessors` or `infer_with_preprocessors` methods. We've provided a sample below:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dDWqZoeW_fRi" - }, - "outputs": [], - "source": [ - "preprocessors = [\n", - " seqio.preprocessors.tokenize,\n", - " seqio.preprocessors.append_eos\n", - "]\n", - "\n", - "interactive_model.train_step_with_preprocessors(examples=training_examples, preprocessors=preprocessors)\n", - "print(f\"Training Summary: {interactive_model.train_summary}\\n\")\n", - "print(f\"Step Number: {interactive_model.step}\\n\")\n", - "\n", - "# Note: when we use a custom list of preprocessors, we must use a general\n", - "# `infer` method, rather than `predict` or `score`. Thus, we must also specify\n", - "# the type of inference to do; valid options are `PREDICT_WITH_AUX`,\n", - "# or `SCORE`.\n", - "examples_and_predictions, _ = interactive_model.infer_with_preprocessors(\n", - " mode=InferenceType.PREDICT_WITH_AUX,\n", - " examples=validation_examples,\n", - " preprocessors=preprocessors)\n", - "predictions = [prediction for example, prediction in examples_and_predictions]\n", - "print(f\"Predictions: {predictions}\\n\")\n", - "\n", - "examples_and_scores, _ = interactive_model.infer_with_preprocessors(\n", - " mode=InferenceType.SCORE,\n", - " examples=validation_examples,\n", - " preprocessors=preprocessors)\n", - "scores = [score for example, score in examples_and_scores]\n", - "print(f\"Scores: {scores}\\n\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sa08OJEuAuhv" - }, - "source": [ - "Because we use the same set of preprocessors, we should expect to see the same results as before.\n", - "\n", - "If you are interested in learning more about preprocessors, please see [this preprocessors guide](https://github.com/google/seqio/blob/main/README.md#preprocessors), which also contains links to implementations of common preprocessors." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PNbOuE-kA21g" - }, - "source": [ - "### Evaluation and Metrics Functions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EXw52bCt8s3A" - }, - "source": [ - "We can similarly run evaluation in a single line. Running evaluation requires that we specify a metric function and (optionally) a list of postprocessors to run on the data before we compute metrics.\n", - "\n", - "There are a variety of sample metrics defined in [t5/evaluation/metrics.py](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/evaluation/metrics.py). For this example, we will use the [SQuAD metric function](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/evaluation/metrics.py) defined in this file. Because we are using natural questions, we will also specify a postprocessor to correctly format question and answer pairs for metrics calculations; specifically, we will use a standard QA postprocessor, modeled after the [`t5.data.postprocessors.qa` method](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/postprocessors.py). We will continue to use the same preprocessors." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_xr9iOEbxSzS" - }, - "outputs": [], - "source": [ - "def qa(answer, example=None, is_target=False):\n", - " \"\"\"Returns answer, or all answers if the full example is provided.\"\"\"\n", - " if is_target:\n", - " return [tf.compat.as_text(a) for a in [example[\"targets_pretokenized\"]]]\n", - " return answer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XLL9_3OwCU9i" - }, - "outputs": [], - "source": [ - "metrics = interactive_model.evaluate_with_preprocessors(\n", - " examples=validation_examples,\n", - " preprocessors=preprocessors,\n", - " metric_fns=[t5_metrics.squad],\n", - " postprocessor=qa)\n", - "print(f\"Metrics: {metrics}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OJ5NHdcoPFVX" - }, - "source": [ - "If you are interested in learning more about metrics functions or postprocessors, please see this [Metrics guide](https://github.com/google/seqio/blob/main/README.md#metrics ) and/or this [Postprocessors guide](https://github.com/google/seqio/blob/main/README.md#postprocessor)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lcDwmp_AxnOG" - }, - "source": [ - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-QR5LnmN4ikp" - }, - "source": [ - "# Advanced Topics" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wi29fMdv4mSr" - }, - "source": [ - "## SeqIO\n", - "\n", - "If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.\n", - "\n", - "We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.\n", - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n", - "\n", - "If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the InteractiveModel above. We've provided an example of this bridge below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Hxu9mRL5yBGK" - }, - "outputs": [], - "source": [ - "!git clone https://github.com/google-research/google-research.git google_research" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bM0nRIEFwyj_" - }, - "outputs": [], - "source": [ - "import google_research.t5_closed_book_qa.t5_cbqa.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='natural_questions_open',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=2,\n", - " seed=42)\n", - "print(f\"Batches: {batches}\")\n", - "# Train the interactive model on the provided batches.\n", - "original_step = interactive_model.step\n", - "_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Elt08160w03X" - }, - "source": [ - "The `get_batches_from_seqio` bridge can take several constructor arguments:\n", - "\n", - "\n", - "1. `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).\n", - "2. `split`: the split of the Task/Mixture to read data from.\n", - "3. `batch_size`: how many examples should appear in each batch.\n", - "4. `num_batches`: the total number of batches to return.\n", - "5. `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \\\n", - "The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \\\n", - "However, some SeqIO tasks do not use tokenization (ex: span corruption). You can set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` and provide an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.\n", - "6. `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.\n", - "7. `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in the `SeqIO.get_dataset` function. If you would like to set any of these arguments, you can set them using this `kwargs` parameter.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "fjKBCX39w0Xl" - }, - "outputs": [], - "source": [ - "import t5.data.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='c4_v220_span_corruption',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=1,\n", - " get_pretokenized_examples=False,\n", - " sequence_length=interactive_model._task_feature_lengths,\n", - " seed=42)\n", - "batch = batches[0] # We expect only a single batch.\n", - "original_step = interactive_model.step\n", - "interactive_model.train_step_with_preprocessors(\n", - " examples=batch, preprocessors=[])\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [ - "RE_CQr9Hcr1D", - "r74x8nJpfe3G", - "a4wzrzzTiRbl" - ], - "last_runtime": { - "build_target": "//learning/grp/tools/ml_python:ml_notebook", - "kind": "private" - }, - "name": "Welcome to T5X: An Introductory Colab", - "private_outputs": true, - "provenance": [ - { - "file_id": "1Akpc6pKlJB5rn5YYYFC9lw2OMk6oBzlQ", - "timestamp": 1662951101563 - }, - { - "file_id": "1rA8bgO2bJRoebAuS96Ji0RUhnawgBY4i", - "timestamp": 1650477076639 - } - ] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/t5x-main/t5x/notebooks/training.ipynb b/t5x-main/t5x/notebooks/training.ipynb deleted file mode 100644 index 1d0436cb82bce67dc4f8a2e72067372b9442e3e8..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/notebooks/training.ipynb +++ /dev/null @@ -1,971 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "-0BQWhvAP2jb" - }, - "source": [ - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bqZYp90PIa1t" - }, - "source": [ - "# Overview\n", - "\n", - "This is the second Colab in a [series of tutorials on how to use T5X](https://github.com/google-research/t5x/blob/main/docs/tutorials.md). We assume that you have already completed https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-intro, or have a basic understanding of the T5X models, checkpoints, partitioner, and `InteractiveModel`.\n", - "\n", - "In the [previous Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series, we presented a quick and easy way to use the `InteractiveModel` to run training on natural text inputs in only a few lines of code. In this Colab, we will dive into how the `InteractiveModel` restores models from checkpoints and runs training, while also getting an introduction to the T5X trainer. It should be noted that the code snippets below exactly replicate the InteractiveModel `__init__()` and `train_step()` methods (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); we expose this functionality here in order to demonstrate how various components of the T5X codebase work together to train a model." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nZJbWZcfkyxI" - }, - "source": [ - "# Set-Up" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rn4C8OdjVn6e" - }, - "source": [ - "Note: If you are a using public colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jIGSIHzD7YPO" - }, - "outputs": [], - "source": [ - "from collections.abc import Sequence\n", - "import enum\n", - "import functools\n", - "import inspect\n", - "import itertools\n", - "import logging\n", - "import os\n", - "import re\n", - "from typing import Any, Callable, Iterator, Optional, Tuple, Union\n", - "\n", - "import jax\n", - "from jax import random\n", - "from jax.experimental import multihost_utils\n", - "import numpy as np\n", - "import seqio\n", - "import tensorflow as tf\n", - "import tensorflow_datasets as tfds\n", - "import t5.data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xAIF2kbAVw-0" - }, - "outputs": [], - "source": [ - "import clu.data\n", - "from t5x.examples.t5 import network\n", - "import t5x\n", - "from t5x import models\n", - "from t5x import partitioning\n", - "from t5x import trainer as trainer_lib\n", - "from t5x import utils\n", - "from t5x.infer import _extract_tokens_and_aux_values\n", - "from t5x.infer import _Inferences\n", - "from t5x.interactive_model import InteractiveModel\n", - "from t5x.interactive_model import get_batches_from_seqio\n", - "from t5x.interactive_model import get_dataset_from_natural_text_examples\n", - "from t5x.interactive_model import get_gin_config_from_interactive_model\n", - "from t5x.interactive_model import T5XScriptType" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S5Lb-Z1fkF5a" - }, - "source": [ - "Before we begin, let's initialize instances of the constructor arguments for the `InteractiveModel`. As mentioned previously, this will enable us to dive into how the `InteractiveModel` restores models from checkpoints and runs training.\n", - "\n", - "If you don't understand the lines of code below, or have questions about how to initialize these parameters, please see the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ne8U8qoWkX_r" - }, - "outputs": [], - "source": [ - "# Define a model. The configuration below corresponds to the T5 1.1 Small model.\n", - "t5_config = network.T5Config(\n", - " vocab_size=32128,\n", - " dtype='bfloat16',\n", - " emb_dim=512,\n", - " num_heads=6,\n", - " num_encoder_layers=8,\n", - " num_decoder_layers=8,\n", - " head_dim=64,\n", - " mlp_dim=1024,\n", - " mlp_activations=('gelu', 'linear'),\n", - " dropout_rate=0.0,\n", - " logits_via_embedding=False)\n", - "module = network.Transformer(config=t5_config)\n", - "model = t5x.models.EncoderDecoderModel(\n", - " module=module,\n", - " input_vocabulary=t5.data.get_default_vocabulary(),\n", - " output_vocabulary=t5.data.get_default_vocabulary(),\n", - " optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))\n", - "# Define checkpoint arguments.\n", - "checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'\n", - "dtype='bfloat16'\n", - "restore_mode='specific'\n", - "# Define a partitioner.\n", - "partitioner=partitioning.PjitPartitioner(num_partitions=2)\n", - "# Define additional, miscellaneous constructor arguments.\n", - "batch_size=8\n", - "task_feature_lengths = {'inputs': 38, 'targets': 18}\n", - "output_dir='/tmp/output_dir'\n", - "input_shapes = {\n", - " 'encoder_input_tokens': np.array([8, 38]),\n", - " 'decoder_target_tokens': np.array([8, 18]),\n", - " 'decoder_input_tokens': np.array([8, 18]),\n", - " 'decoder_loss_weights': np.array([8, 18])\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ib9aOi2xaCKQ" - }, - "source": [ - "# Training Deep Dive" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-trKAuOWaGoB" - }, - "source": [ - "Let's start by going over what happens when we initialize the InteractiveModel.\n", - "\n", - "The `InteractiveModel` `__init__()` method performs six main actions:\n", - "\n", - "\n", - "1. Configure and possibly create an output directory.\n", - "2. Initialize RNGs.\n", - "3. Validate the partitioner.\n", - "4. Create a checkpoint manager.\n", - "5. Restore the model from a checkpoint or initialize from scratch.\n", - "6. Create a trainer.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zMPQQxR6lW78" - }, - "source": [ - "**Configuring the Output Directory** \\\n", - "There is minimal work required to configure the output directory for our model: we simply remove double-slashes in the directory path to avoid inconsistencies and create the directory if it doesn't already exist." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3wyAW8drm1JV" - }, - "outputs": [], - "source": [ - "output_dir = re.sub(r\"(?\u003c!gs:)([\\/]{2,})\", \"/\", output_dir)\n", - "if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5wGHXxSynEs4" - }, - "source": [ - "**Initializing RNGs** \\\n", - "Initializing RNGs is made fairly straightforward with the use of JAX random operations.\n", - "\n", - "\n", - "\n", - "We first set an initial seed using the `multihost_utils` tools, then define an RNG using the JAX `PRNGKey` utils, and finally split this RNG into two values: one each for initializing the model and training the model. This ensures that we never reuse an RNG key." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wsOqpNXUnl81" - }, - "outputs": [], - "source": [ - "init_random_seed = 42\n", - "random_seed = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed))\n", - "utils.set_hardware_rng_ops()\n", - "rng = random.PRNGKey(random_seed)\n", - "init_rng, trainer_rng = random.split(rng, 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uiw55rQ_n-KP" - }, - "source": [ - "**Validating the Partitioner** \\\n", - "\n", - "Because we've already constructed the partitioner, we simply need to validate that it was constructed properly. In particular, we need to ensure that the number of partitions created by the partitioner can easily divide the total number of devices." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "HK2SfW9RoQjE" - }, - "outputs": [], - "source": [ - "if partitioner._model_parallel_submesh:\n", - " num_partitions = np.prod(partitioner._model_parallel_submesh)\n", - "else:\n", - " num_partitions = partitioner._num_partitions\n", - "if jax.device_count() % num_partitions != 0:\n", - " raise ValueError(\n", - " \"The number of devices available must be a multiple of the number of\",\n", - " f\" partitions. There are {jax.device_count()} devices available, but\",\n", - " f\" the number of partitions is set to {num_partitions}. Please\",\n", - " \" provide a different number of partitions.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "C0sjeDUnoYz7" - }, - "source": [ - "**Create a Checkpoint Manager**\n", - "\n", - "We make use of the T5X [`LegacyCheckpointManager`](https://github.com/google-research/t5x/blob/main/t5x/utils.py) to restore our model and save any future checkpoints. The `LegacyCheckpointManager` requires several constructor arguments:\n", - "\n", - "\n", - "\n", - "1. `save_checkpoint_cfg`: an instance of the `SaveCheckpointConfig` wrapper class, which contains information about where and how to save future checkpoints.\n", - "2. `restore_checkpoint_cfg`: an instance of the `RestoreCheckpointConfig` wrapper class, which contains information and where and how to load checkpoints and restore the model.\n", - "3. `train_state_shape`: our model will load and save a T5X [`TrainState`](https://github.com/google-research/t5x/blob/main/t5x/train_state.py), which (as the name implies) stores information about the current state of training. We provide information about the shape of this train state to the checkpoint manager to enable saving this train state in checkpoints.\n", - "4. `partitioner`: our predefined partitioner.\n", - "5. `model_dir`: our previously configured output directory, where we will save any future checkpoints.\n", - "\n", - "Before we define these constructor arguments and initialize the checkpoint manager, let's discuss the T5X `TrainState` in a bit more depth. Importantly, T5X is a JAX-based library, which means that all of our methods follow typical functional programming patterns.\n", - "\n", - "\n", - "Specifically, our training methods cannot have side effects, so we pass all model parameters, step number, optimizer state, etc. as input and get updated values as output from our methods. We use the T5X `TrainState` to hold all our model parameters, step number, optimizer state, etc. and we will later define a `train_step` method that will take in the train state and return an updated train state with new values.\n", - "\n", - "We define these constructor arguments and initialize the checkpoint manager below.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nC18r9Uaq2wL" - }, - "outputs": [], - "source": [ - "# Define CheckpointCfg wrappers.\n", - "save_checkpoint_cfg = utils.SaveCheckpointConfig(\n", - " dtype=dtype,\n", - " keep=5, # The number of checkpoints to keep in the output_dir.\n", - " save_dataset=False)\n", - "restore_checkpoint_cfg = utils.RestoreCheckpointConfig(\n", - " dtype=dtype,\n", - " mode=restore_mode,\n", - " path=checkpoint_path)\n", - "\n", - "# Define a train state initializer, which will help us get information about the\n", - "# TrainState shape.\n", - "train_state_initializer = utils.TrainStateInitializer(\n", - " optimizer_def=model.optimizer_def,\n", - " init_fn=model.get_initial_variables,\n", - " input_shapes=input_shapes,\n", - " input_types=None,\n", - " partitioner=partitioner)\n", - "\n", - "checkpoint_manager = utils.LegacyCheckpointManager(\n", - " save_cfg=save_checkpoint_cfg,\n", - " restore_cfg=restore_checkpoint_cfg,\n", - " train_state_shape=train_state_initializer.global_train_state_shape,\n", - " partitioner=partitioner,\n", - " ds_iter=None,\n", - " model_dir=output_dir)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LSSfevHCsWX6" - }, - "source": [ - "**Restore the Model from a Checkpoint or Initialize from Scratch** \\\n", - "\n", - "We try two different strategies for restoring a model. First, we try to restore the model from a checkpoint using the `CheckpointManager`. If no checkpoint can be found (likely because no path was provided in `checkpoint_path`), then we will initialize the model from scratch.\n", - "\n", - "Finally, we will log model initialization information (such as parameter shape, partitioning annotation, etc.) to the output directory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-PgHNNOfuuFj" - }, - "outputs": [], - "source": [ - "def get_state(rng):\n", - " return train_state_initializer.from_scratch(rng).state_dict()\n", - "\n", - "# 1. Try to restore a model from a checkpoint.\n", - "train_state = checkpoint_manager.restore(\n", - " [restore_checkpoint_cfg.path],\n", - " restore_checkpoint_cfg,\n", - " utils.get_fallback_state(restore_checkpoint_cfg, get_state, init_rng)\n", - ")\n", - "\n", - "# 2. If no checkpoint to restore, init from scratch.\n", - "if train_state is None:\n", - " train_state = train_state_initializer.from_scratch(init_rng)\n", - "\n", - "# Validate that we got an expected form of TrainState.\n", - "if isinstance(train_state, Sequence):\n", - " raise ValueError(\n", - " \"Expected a single train state, but instead received a Sequence.\")\n", - "train_state_axes = train_state_initializer.train_state_axes\n", - "\n", - "# Log the variable shapes information and write to a file.\n", - "log_file = os.path.join(output_dir, \"model-info.txt\")\n", - "utils.log_model_info(log_file,\n", - " train_state_initializer.global_train_state_shape,\n", - " partitioner)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7JL5Nh6RvVbC" - }, - "source": [ - "**Create a Trainer**\n", - "\n", - "Finally, we use many of the parameters we've defined above to create an instance of the T5X [Trainer](https://github.com/google-research/t5x/blob/main/t5x/trainer.py). The trainer takes in several constructor arguments:\n", - "\n", - "\n", - "\n", - "1. `model`: the model that will be trained\n", - "2. `train_state`: a train state with parameters and optimizer state, which we've restored or initialized above.\n", - "3. `partitioner`: the partitioner to use.\n", - "4. `summary_dir`: the output directory, where we can write summaries of training.\n", - "5. `train_state_axes`: partitioning information for the optimizer, which we've initialized above.\n", - "6. `rng`: the JAX RNG to be used for training.\n", - "7. `learning_rate_fn`: a function that returns the learning rate given the current step. T5X provides some helper functions that define common learning rate schedules; we will use one of these helpers to define the learning rate in our example.\n", - "\n", - "We initialize a sample Trainer below.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pmVqBHrny2p3" - }, - "outputs": [], - "source": [ - "trainer = trainer_lib.Trainer(\n", - " model=model,\n", - " train_state=train_state,\n", - " partitioner=partitioner,\n", - " eval_names=[],\n", - " summary_dir=output_dir,\n", - " train_state_axes=train_state_axes,\n", - " rng=trainer_rng,\n", - " learning_rate_fn=utils.create_learning_rate_scheduler(),\n", - " num_microbatches=None)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5_06UyFrlXFT" - }, - "source": [ - "The code snippets above exactly replicate the `InteractiveModel` `__init__()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running the single code snippet below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Q_VA6DUi0MFv" - }, - "outputs": [], - "source": [ - "interactive_model = InteractiveModel(\n", - " batch_size=batch_size,\n", - " task_feature_lengths=task_feature_lengths,\n", - " output_dir=output_dir,\n", - " partitioner=partitioner,\n", - " model=model,\n", - " dtype=dtype,\n", - " restore_mode=restore_mode,\n", - " checkpoint_path=checkpoint_path,\n", - " input_shapes=input_shapes\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ANqpfv0lAVqL" - }, - "source": [ - "**Defining a Batch of Examples to Train On**\\\n", - "We are now ready to begin training!\n", - "\n", - "First, we'll begin by defining a batch of examples to train on; these examples should either be a list of inputs, or a list of dictionaries mapping 'target'/'input' keys to corresponding values, as shown below. For this Colab, we'll use a set of natural test questions and answers." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yhhR0yDcAn7w" - }, - "outputs": [], - "source": [ - "examples = [\n", - " {\n", - " 'target': b'Ajay Tyagi',\n", - " 'input':b'nq question: who has been appointed as the new chairman of sebi'\n", - " },\n", - " {\n", - " 'target': b'C. S. Lewis',\n", - " 'input': b'nq question: who wrote the book lion the witch and the wardrobe'},\n", - " {\n", - " 'target': b'29',\n", - " 'input': b'nq question: how many planes did japan lose at pearl harbor'},\n", - " {\n", - " 'target': b'Jack Keil',\n", - " 'input': b'nq question: who does the voice of mcgruff the dog'},\n", - " {\n", - " 'target': b'Journey',\n", - " 'input': b'nq question: who sings the wheels in the sky keep on turning'},\n", - " {\n", - " 'target': b'Kumiko Watanabe',\n", - " 'input': b'nq question: who voices regina in glitter force doki doki'},\n", - " {\n", - " 'target': b'during World War II',\n", - " 'input': b'nq question: when did the us become allies with britain'},\n", - " {\n", - " 'target': b'the United States',\n", - " 'input': b'nq question: who won the rugby 7 in las vegas'},\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jLc-P9S6_uOD" - }, - "source": [ - "We also define the required features of the examples. For this Colab, we will only require an `inputs` and `targets` entry, as defined below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "i5n6bnSq_2jF" - }, - "outputs": [], - "source": [ - "output_features = {\n", - " \"inputs\":\n", - " seqio.Feature(\n", - " vocabulary=model.input_vocabulary, add_eos=True),\n", - " \"targets\":\n", - " seqio.Feature(\n", - " vocabulary=model.output_vocabulary, add_eos=True)\n", - " }\n", - "features = dict(sorted(output_features.items()))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mixLzcBkQOT_" - }, - "source": [ - "Now, let's (similarly) break down what the interactive model does when it takes a single step of training.\n", - "\n", - "The `InteractiveModel` `train_step()` method only performs two actions:\n", - "\n", - "\n", - "1. Convert the natural text examples into a tf.Dataset.\n", - "2. Take a single training step, using the T5X Trainer.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ug0zJx2kQk6g" - }, - "source": [ - "**Prepare the dataset** \\\n", - "\n", - "Preparing the data for training is fairly straightforward. First, we validate that enough examples have been provided to train on a full batch of data.\n", - "\n", - "Then, we convert the natural text examples into a tf.Dataset and run any preprocessors; T5X has a helper function, `get_dataset_from_natural_text_examples`, that can do exactly that. For this example, the only preprocessing we will do is tokenization and appending an EOS token. If you are interested in learning more about preprocessors, please take a look at the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series.\n", - "\n", - "Finally, we convert all features using the model's feature converter and pad all batches of data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "chPomDFxQ6r3" - }, - "outputs": [], - "source": [ - "# Validate num examples.\n", - "if len(examples) \u003c batch_size:\n", - " raise ValueError(\n", - " \"At least one batch of data must be provided. Please decrease the \"\n", - " \"batch_size or provide more examples.\")\n", - "# Get a tf.Dataset.\n", - "train_dataset = get_dataset_from_natural_text_examples(\n", - " examples=examples,\n", - " preprocessors=[\n", - " seqio.preprocessors.tokenize,\n", - " seqio.preprocessors.append_eos\n", - " ],\n", - " task_feature_lengths=task_feature_lengths,\n", - " features=features)\n", - "\n", - "# Convert and pad features.\n", - "feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)\n", - "train_dataset = feature_converter(\n", - " train_dataset, task_feature_lengths=task_feature_lengths)\n", - "train_dataset = train_dataset.padded_batch(batch_size, drop_remainder=True)\n", - "train_iter = clu.data.dataset_iterator.TfDatasetIterator(train_dataset, checkpoint=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "toq2uz7dAfdL" - }, - "source": [ - "**Run 1 Training Step** \\\n", - "\n", - "We'll define a helper function that takes a single train step, making it easy to loop over this helper and train for multiple steps.\n", - "\n", - "Training is made fairly straightforward because of the T5X trainer. We'll simply add some logic to validate that it's ok for training to occur and to save a checkpoint. In total, we'll perform the following actions:\n", - "\n", - "\n", - "1. Validate that training can occur.\n", - "2. Take a training step.\n", - "3. Save a checkpoint.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Qv8RXHeXK6mk" - }, - "outputs": [], - "source": [ - "def train_step(\n", - " trainer: t5x.trainer.Trainer,\n", - " train_state: t5x.train_state.TrainState,\n", - " train_iter: clu.data.dataset_iterator.TfDatasetIterator,\n", - " checkpoint_manager: utils.LegacyCheckpointManager,\n", - " save_checkpoint_cfg: utils.SaveCheckpointConfig):\n", - " # Validate that training can occur.\n", - " if trainer.stop_training:\n", - " logging.info(\"Stopping training early since `stop_training` is requested.\")\n", - " return\n", - "\n", - " # Take a training step.\n", - " try:\n", - " first_step = int(utils.get_local_data(train_state.step))\n", - " train_summary = trainer.train(\n", - " train_iter, 1, start_step=first_step)\n", - " except trainer_lib.PreemptionError as e:\n", - " logging.info(\"Saving emergency checkpoint.\")\n", - " checkpoint_manager.save(\n", - " trainer.train_state,\n", - " save_checkpoint_cfg.state_transformation_fns)\n", - " logging.info(\"Saving emergency checkpoint done.\")\n", - " raise e\n", - "\n", - " # Save a checkpoint.\n", - " logging.info(\"Saving checkpoint.\")\n", - " checkpoint_manager.save(\n", - " trainer.train_state,\n", - " save_checkpoint_cfg.state_transformation_fns)\n", - "\n", - " # Wait until computations are done before exiting\n", - " multihost_utils.sync_global_devices(\"complete\")\n", - " return trainer.train_state, train_summary.result()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "srgktjMKMSZt" - }, - "outputs": [], - "source": [ - "print(f\"Current Step: {train_state.step}\")\n", - "train_state, train_summary = train_step(trainer, train_state, train_iter, checkpoint_manager, save_checkpoint_cfg)\n", - "print(f\"Current Step: {train_state.step}\")\n", - "print(f\"Summary of Training: {train_summary}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AB0U_kfRNIyR" - }, - "source": [ - "The code snippets above exactly replicate the `InteractiveModel` `train_step()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running `interactive_model.train_step(examples)`.\n", - "\n", - "Alternately, you can loop over this helper function multiple times to finetune or pretrain a model (the code snippet below may take ~5 mins to run)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_7GJbHgNNRtl" - }, - "outputs": [], - "source": [ - "num_steps = 100\n", - "for _ in range(num_steps):\n", - " # Reset the iterator, since we use the same batch for every step.\n", - " train_iter = clu.data.dataset_iterator.TfDatasetIterator(train_dataset, checkpoint=False)\n", - " train_state, train_summary = train_step(\n", - " trainer,\n", - " train_state,\n", - " train_iter,\n", - " checkpoint_manager,\n", - " save_checkpoint_cfg\n", - " )\n", - "print(f\"Current Step: {train_state.step}\")\n", - "print(f\"Summary of Training: {train_summary}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gs3BwU5sKRmu" - }, - "source": [ - "The code snippets above demonstrate how T5X runs training. You can exactly replicate this behavior by using the `InteractiveModel`, as described above." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lcDwmp_AxnOG" - }, - "source": [ - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-QR5LnmN4ikp" - }, - "source": [ - "# Advanced Topics" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CLstCKpP8Ge7" - }, - "source": [ - "## T5X Training Binaries and Other Advanced Features" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h9RT_9R_8K3U" - }, - "source": [ - "T5X offers training binaries that have the same functionality as the InteractiveModel, with additional features as well (more advanced compiling, custom checkpointing periods, etc.). Importantly, these binaries are configured using [Gin](https://github.com/google/gin-config/blob/main/README.md); if you are not familiar with Gin, please take a look at this [Gin Primer](https://github.com/google-research/t5x/blob/main/docs/usage.md/gin) to get started.\n", - "\n", - "If you are familiar with Gin and interested in using the T5X training binaries, we have provided a helper function, `get_gin_config_from_interactive_model`, which will take an InteractiveModel instance and generate the gin config that you can use to run the T5X training binaries; this gin config will exactly reproduce the InteractiveModel training functionality we've described above. We've provided an example below.\n", - "\n", - "Importantly, the InteractiveModel takes in a model, partitioner, and data, so we cannot generate Gin configs for these components. You can pass Gin config strings for the model and partitioner components to the helper function, as demonstrated below. Additionally, you can pass a SeqIO task containing your data to the helper function. See the section below if you are unfamiliar with SeqIO." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rhgUZ0w6yQsE" - }, - "outputs": [], - "source": [ - "# Define an InteractiveModel instance, based on the `tiny` T5X EncoderDecoder model.\n", - "input_shapes = {\n", - " 'encoder_input_tokens': np.array([8, 38]),\n", - " 'decoder_target_tokens': np.array([8, 18]),\n", - " 'decoder_input_tokens': np.array([8, 18]),\n", - " 'decoder_loss_weights': np.array([8, 18])\n", - "}\n", - "t5_config = network.T5Config(\n", - " vocab_size=32128,\n", - " dtype='bfloat16',\n", - " emb_dim=8,\n", - " num_heads=4,\n", - " num_encoder_layers=2,\n", - " num_decoder_layers=2,\n", - " head_dim=3,\n", - " mlp_dim=16,\n", - " mlp_activations=('gelu', 'linear'),\n", - " dropout_rate=0.0,\n", - " logits_via_embedding=False)\n", - "module = network.Transformer(config=t5_config)\n", - "model = t5x.models.EncoderDecoderModel(\n", - " module=module,\n", - " input_vocabulary=t5.data.get_default_vocabulary(),\n", - " output_vocabulary=t5.data.get_default_vocabulary(),\n", - " optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),\n", - " decode_fn=functools.partial(\n", - " t5x.decoding.temperature_sample, temperature=1.0, topk=40))\n", - "interactive_model = InteractiveModel(\n", - " batch_size=8,\n", - " task_feature_lengths={\n", - " 'inputs': 32,\n", - " 'targets': 32\n", - " },\n", - " output_dir='/tmp',\n", - " partitioner=partitioning.PjitPartitioner(\n", - " num_partitions=2,\n", - " model_parallel_submesh=None,\n", - " logical_axis_rules=partitioning.standard_logical_axis_rules()),\n", - " model=model,\n", - " dtype='float32',\n", - " restore_mode='specific',\n", - " checkpoint_path='',\n", - " input_shapes=input_shapes,\n", - " input_types=None)\n", - "\n", - "# Define Gin Config strings for the model, partitioner, and any imports.\n", - "imports_str = \"\"\"from t5x import models\n", - "from t5x import partitioning\n", - "import t5.data.mixtures\n", - "include 't5x/examples/t5/t5_1_1/tiny.gin'\"\"\"\n", - "partitioner_config = 'partitioning.PjitPartitioner.num_partitions = 2'\n", - "model_config = \"\"\"models.EncoderDecoderModel:\n", - " z_loss = 0.0\n", - " label_smoothing = 0.0\n", - " loss_normalizing_factor = None\"\"\"\n", - "\n", - "gin_config_str = get_gin_config_from_interactive_model(\n", - " interactive_model=interactive_model,\n", - " script_type=T5XScriptType.PRETRAINING,\n", - " task_name='wmt19_ende_v003',\n", - " partitioner_config_str=partitioner_config,\n", - " model_config_str=model_config,\n", - " train_steps=3,\n", - " imports_str=imports_str)\n", - "print(gin_config_str)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uGd1DxDT3gB7" - }, - "source": [ - "\n", - "Once you have generated the `gin_config_str` as above, you can write this string to a file and launch your training experiment locally by running the following on commandline:\n", - "\n", - "\n", - "```\n", - "MODEL_DIR=\"/tmp/pretrain-model/\"\n", - "python -m t5x.train \\\n", - " --gin_file=${GIN_FILE_PATH} \\\n", - " --gin.MODEL_DIR=\\\"${MODEL_DIR}\\\" \\\n", - " --alsologtostderr\n", - "```\n", - "\n", - "For more details on training using the T5X training binaries, please see the [Pretraining](https://github.com/google-research/t5x/blob/main/docs/usage.md/pretrain) or [Finetuning](https://github.com/google-research/t5x/blob/main/docs/usage.md/finetune) tutorials." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wi29fMdv4mSr" - }, - "source": [ - "## SeqIO\n", - "\n", - "If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.\n", - "\n", - "We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md/index) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.\n", - "\n", - "\u003ca href=\"https://colab.research.google.com/github/google-research/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n", - "\n", - "If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the code snippets above. We've provided an example of this bridge below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kh5VMAWcV_WG" - }, - "outputs": [], - "source": [ - "!git clone https://github.com/google-research/google-research.git google_research" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bM0nRIEFwyj_" - }, - "outputs": [], - "source": [ - "import google_research.t5_closed_book_qa.t5_cbqa.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='natural_questions_open',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=2,\n", - " seed=42)\n", - "print(f\"Batches: {batches}\")\n", - "# Train the interactive model on the provided batches.\n", - "original_step = interactive_model.step\n", - "_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Elt08160w03X" - }, - "source": [ - "The `get_batches_from_seqio` bridge can take several constructor arguments:\n", - "\n", - "\n", - "1. `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).\n", - "2. `split`: the split of the Task/Mixture to read data from.\n", - "3. `batch_size`: how many examples should appear in each batch.\n", - "4. `num_batches`: the total number of batches to return.\n", - "5. `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \\\n", - "The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \\\n", - "However, some SeqIO tasks do not use tokenization (ex: span corruption). You can set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` and provide an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.\n", - "6. `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.\n", - "7. `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in the `SeqIO.get_dataset` function. If you would like to set any of these arguments, you can set them using this `kwargs` parameter.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "fjKBCX39w0Xl" - }, - "outputs": [], - "source": [ - "import t5.data.tasks\n", - "batches = get_batches_from_seqio(\n", - " task_or_mixture_name='c4_v220_span_corruption',\n", - " split='validation',\n", - " batch_size=8,\n", - " num_batches=1,\n", - " get_pretokenized_examples=False,\n", - " sequence_length=interactive_model._task_feature_lengths,\n", - " seed=42)\n", - "batch = batches[0] # We expect only a single batch.\n", - "original_step = interactive_model.step\n", - "interactive_model.train_step_with_preprocessors(\n", - " examples=batch, preprocessors=[])\n", - "print(f\"Original Step: {original_step}, Current Step: {interactive_model.step}\")" - ] - } - ], - "metadata": { - "colab": { - "last_runtime": { - "build_target": "//learning/grp/tools/ml_python:ml_notebook", - "kind": "private" - }, - "name": "Welcome to T5X: Training Deep Dive", - "private_outputs": true, - "provenance": [ - { - "file_id": "1hQO9MD6psZtTeqZyXPJIoUV0uzTa2qPg", - "timestamp": 1662951508591 - }, - { - "file_id": "1Akpc6pKlJB5rn5YYYFC9lw2OMk6oBzlQ", - "timestamp": 1662754223629 - }, - { - "file_id": "1rA8bgO2bJRoebAuS96Ji0RUhnawgBY4i", - "timestamp": 1650477076639 - } - ] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/t5x-main/t5x/optimizers.py b/t5x-main/t5x/optimizers.py deleted file mode 100644 index 8b1283adf6c972c2febfc914b1a80c507165aba0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/optimizers.py +++ /dev/null @@ -1,796 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5X Optimizer Support. - -Tools for wrapping Optax optimizers and handling SPMD annotations for use with -pjit. - -Additional support for the legacy Adafactor implementation. -""" - -import functools -from typing import Any, Mapping, Optional, Sequence, Tuple, Union - -import flax -# just used for transitional type definitions -from flax import serialization -from flax import struct -from flax import traverse_util -from flax.core import frozen_dict -from flax.serialization import from_state_dict -from flax.serialization import to_state_dict -import jax -import jax.numpy as jnp -from jestimator import amos -from jestimator import amos_helper -import optax - -freeze = flax.core.frozen_dict.freeze -unfreeze = flax.core.frozen_dict.unfreeze - -Dtype = Any - - -@struct.dataclass -class OptimizerState: - step: jnp.ndarray - param_states: Any - - -class OptimizerDef: - """Base class for an optimizer definition.""" - - def __init__(self, hyper_params): - self.hyper_params = hyper_params - - def apply_gradient(self, hyper_params, params, state, grads): - """Applies a gradient for a set of parameters.""" - raise NotImplementedError() - - def init_state(self, params): - raise NotImplementedError() - - def update_hyper_params(self, **hyper_param_overrides): - """Updates the hyper parameters with a set of overrides. - - Args: - **hyper_param_overrides: the hyper parameters updates will override the - defaults specified in the `OptimizerDef`. Pass `hyper_params=...` to - replace all hyper parameters. - - Returns: - The new hyper parameters. - """ - hp = hyper_param_overrides.pop('hyper_params', self.hyper_params) - if hyper_param_overrides: - hp = hp.replace(**hyper_param_overrides) - return hp - - def create(self, target): - """Creates a new optimizer for the given target. - - Args: - target: the object to be optimized. This is typically a variable dict - returned by `flax.linen.Module.init()`, but it can also be a container - of variables dicts, e.g. `(v1, v2)` and `('var1': v1, 'var2': v2)` are - valid inputs as well. - - Returns: - An instance of `Optimizer`. - """ - opt_def = self - state = opt_def.init_state(target) - return Optimizer(opt_def, state, target) - - def state_dict(self, target, state): - return to_state_dict( - {'target': to_state_dict(target), 'state': to_state_dict(state)} - ) - - def restore_state(self, opt_target, opt_state, state_dict): - """Restore the optimizer target and state from the state dict. - - Args: - opt_target: the optimizer target. - opt_state: the optimizer state. - state_dict: the state dict containing the desired new state of the - optimizer. - - Returns: - a tuple of the optimizer target and state with the restored values from - the state dict. - """ - - opt_target = from_state_dict(opt_target, state_dict['target']) - opt_state = from_state_dict(opt_state, state_dict['state']) - return opt_target, opt_state - - -class Optimizer(struct.PyTreeNode): - """Legacy flax optimizer class. - - Optimizer carries the target and optimizer state. The optimizer is updated - using the method apply_gradient. - - Attributes: - optimizer_def: The optimizer definition. - state: The initial state of the optimizer. - target: The target to optimizer. - """ - - optimizer_def: OptimizerDef = struct.field(pytree_node=False) - state: Any = struct.field(pytree_node=True) - target: Any = struct.field(pytree_node=True) - - def apply_gradient(self, grads, **hyper_param_overrides): - """Applies a pytree of gradients to the target. - - Args: - grads: A pytree of gradients. - **hyper_param_overrides: the hyper parameters passed to apply_gradient - will override the defaults specified in the `OptimizerDef`. Pass - `hyper_params=...` to replace all hyper parameters. - - Returns: - A new optimizer with the updated target and state. - """ - hyper_params = self.optimizer_def.update_hyper_params( - **hyper_param_overrides - ) - new_target, new_state = self.optimizer_def.apply_gradient( - hyper_params, self.target, self.state, grads - ) - return self.replace(target=new_target, state=new_state) - - def state_dict(self): - return self.optimizer_def.state_dict(self.target, self.state) - - def restore_state(self, state): - target, state = self.optimizer_def.restore_state( - self.target, self.state, state - ) - return self.replace(target=target, state=state) - - -# Transitional Type Definitions - -OptimizerType = Optimizer -OptimizerStateType = Union[OptimizerState, Mapping[str, Any]] -OptimizerDefType = OptimizerDef - - -# Optax Elementwise Wrapper - - -def _scale_by_schedule_ctor(state, params_axes): - del state, params_axes - return optax.ScaleByScheduleState( # pytype: disable=wrong-arg-types # numpy-scalars - count=None - ) - - -class OptaxStatePartitionRules: - """Collection of rules to partition optax states. - - These rules work for optimizers whose states are simply replications of - params, e.g., Adam. Optimizers that aim to save memory by factoring states, - e.g., Adafactor, SM3, are not supported currently. - """ - - # Rules mapping a particular optax state to a callable returning the state - # with arrays replaced by t5x PartitionSpec or None. - # - # NOTE(levskaya): This is not an entirely exhaustive list, add to this list - # to support additional optimizers / transformations. - # - # pylint: disable=g-long-lambda - - _RULES = { - # Leaf Optax States: - amos.ScaleByAmosState: amos_helper.state_partition_rule, - optax.AddNoiseState: lambda state, params_axes: optax.AddNoiseState( # pytype: disable=wrong-arg-types # numpy-scalars - count=None, rng_key=None - ), - optax.DifferentiallyPrivateAggregateState: ( - lambda state, params_axes: optax.DifferentiallyPrivateAggregateState( - rng_key=None - ) - ), - optax.EmaState: lambda state, params_axes: optax.EmaState( # pytype: disable=wrong-arg-types # numpy-scalars - count=None, - ema=OptaxStatePartitionRules.derive_params_axes( - state.ema, params_axes - ), - ), - optax.EmptyState: lambda state, params_axes: optax.EmptyState(), - optax.TraceState: lambda state, params_axes: optax.TraceState( - trace=OptaxStatePartitionRules.derive_params_axes( - state.trace, params_axes - ) - ), - optax.ScaleByAdamState: lambda state, params_axes: optax.ScaleByAdamState( # pytype: disable=wrong-arg-types # numpy-scalars - count=None, - mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes), - nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes), - ), - optax.ScaleByBeliefState: lambda state, params_axes: optax.ScaleByBeliefState( # pytype: disable=wrong-arg-types # numpy-scalars - count=None, - mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes), - nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes), - ), - optax.ScaleByLionState: lambda state, params_axes: optax.ScaleByLionState( # pytype: disable=wrong-arg-types # numpy-scalars - count=None, - mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes), - ), - optax.ScaleByRssState: lambda state, params_axes: optax.ScaleByRssState( - sum_of_squares=OptaxStatePartitionRules.derive_params_axes( - state.sum_of_squares, params_axes - ) - ), - optax.ScaleByRmsState: lambda state, params_axes: optax.ScaleByRmsState( - nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes) - ), - optax.ScaleByRStdDevState: ( - lambda state, params_axes: optax.ScaleByRStdDevState( - mu=OptaxStatePartitionRules.derive_params_axes( - state.mu, params_axes - ), - nu=OptaxStatePartitionRules.derive_params_axes( - state.nu, params_axes - ), - ) - ), - optax.ScaleBySM3State: lambda state, params_axes: optax.ScaleBySM3State( - mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes), - nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes), - ), - optax.ScaleByTrustRatioState: ( - lambda state, params_axes: optax.ScaleByTrustRatioState() - ), - optax.ScaleByScheduleState: _scale_by_schedule_ctor, - optax.ZeroNansState: lambda state, params_axes: optax.ZeroNansState( - found_nan=None - ), - # FactoredState - # Recursive, Combinator Optax States: - # MaskedState - optax.MaskedState: lambda state, params_axes: optax.MaskedState( - inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( - state.inner_state, params_axes - ) - ), - optax.InjectHyperparamsState: lambda state, params_axes: optax.InjectHyperparamsState( # pytype: disable=wrong-arg-types # jax-ndarray - count=None, - hyperparams=jax.tree.map(lambda x: None, state.hyperparams), - inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( - state.inner_state, params_axes - ), - ), - optax.MultiStepsState: lambda state, params_axes: optax.MultiStepsState( # pytype: disable=wrong-arg-types # jax-ndarray - mini_step=None, - gradient_step=None, - inner_opt_state=OptaxStatePartitionRules.derive_optax_logical_axes( # pylint: disable=line-too-long - state.inner_opt_state, params_axes - ), - acc_grads=params_axes, - ), - optax.ApplyIfFiniteState: ( - lambda state, params_axes: optax.ApplyIfFiniteState( - notfinite_count=None, - last_finite=None, - total_notfinite=None, - inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( - state.inner_state, params_axes - ), - ) - ), - optax.ConditionallyTransformState: lambda state, params_axes: optax.ConditionallyTransformState( # pytype: disable=wrong-arg-types # jax-ndarray - inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( - state.inner_state, params_axes - ), - step=None, - ), - optax.MultiTransformState: ( - lambda state, params_axes: optax.MultiTransformState( - inner_states=OptaxStatePartitionRules.derive_optax_logical_axes( - state.inner_states, params_axes - ) - ) - ), - # LookaheadState - # SplitRealAndImaginaryState - } - # pylint: enable=g-long-lambda - - @classmethod - def _is_optax_state(cls, x): - """Returns true if an object is an optax state. - - Note that in optax states are simply derived from NamedTuple, so we have to - do some hacky name matching. - - Args: - x: object. - - Returns: - True if x is an optax state. - """ - # A solution from stack overflow. Note that isinstance(x, NamedTuple) would - # not work. - is_named_tuple = ( - isinstance(x, tuple) and hasattr(x, '_asdict') and hasattr(x, '_fields') - ) - result = is_named_tuple and type(x).__name__.endswith('State') - return result - - @classmethod - def derive_optax_logical_axes(cls, optax_state, params_axes): - """Derived logical axes for optax state.""" - # Flatten the optax state but do not go into the registered states. - flattened_state, tree_def = jax.tree_util.tree_flatten( - optax_state, is_leaf=cls._is_optax_state - ) - - def derive_fn(x): - if type(x) not in cls._RULES: - if cls._is_optax_state(x): - raise ValueError( - f'Encountered unregistered optax state type {type(x).__name__}' - ) - return None - return cls._RULES[type(x)](x, params_axes) - - flattened_axes = [derive_fn(x) for x in flattened_state] - derived_axes = jax.tree_util.tree_unflatten(tree_def, flattened_axes) - return derived_axes - - @classmethod - def derive_params_axes(cls, optax_params, params_axes): - """Derive axes for params inside optax state.""" - # Params masked by optax should not have a corresponding PartitionSpec. - return jax.tree_util.tree_map( - lambda x, y: x if not isinstance(y, optax.MaskedNode) else y, - params_axes, - optax_params, - ) - - -@struct.dataclass -class _OptaxWrapperHyperParams: - """Dummy hyper params struct, not used.""" - - # Required by t5x trainer. Unused as learning rate scheduling is done using - # optax.Schedule. - learning_rate: Optional[float] = None - - -class OptaxWrapper(OptimizerDef): - """Wrapper to make optax optimizer compatible with T5X.""" - - def __init__(self, optax_optimizer: optax.GradientTransformation): - """Initializer. - - Args: - optax_optimizer: An optax optimizer. - """ - self.optax_optimizer = optax_optimizer - super().__init__(hyper_params=_OptaxWrapperHyperParams()) - - def init_state(self, params): - """Create initial state based on the params to optimize. - - Args: - params: PyTree of parameters to optimize. - - Returns: - Initial optimizer state. - """ - state = OptimizerState( # pytype: disable=wrong-arg-types # jax-ndarray - step=0, param_states=self.optax_optimizer.init(params) - ) - return state - - def apply_gradient(self, hyper_params, params, state, grads): - """Applies gradient. - - Args: - hyper_params: Unused hyper parameters. - params: PyTree of the parameters. - state: A named tuple containing the state of the optimizer. - grads: PyTree of the gradients for the parameters. - - Returns: - A tuple containing the new parameters and the new optimizer state. - """ - del hyper_params - - updates, new_optax_state = self.optax_optimizer.update( - grads, state.param_states, params - ) - new_params = optax.apply_updates(params, updates) - return new_params, OptimizerState( - step=state.step + 1, param_states=new_optax_state - ) - - def derive_logical_axes(self, optimizer, param_logical_axes): - """Derives optimizer state logical axes from params logical axes. - - Args: - optimizer: `optimizers.Optimizer` instance. - param_logical_axes: A PyTree where each leaf is a t5x PartitionSpec. - - Returns: - An `optimizers.Optimizer` instance, with all the leafs replaced by t5x - PartitionSpec or None (no partition). - """ - optimizer_logical_axes = jax.tree.map( - lambda x: None, optimizer.state_dict() - ) - optimizer_logical_axes['target'] = param_logical_axes - - optax_state_axes = OptaxStatePartitionRules.derive_optax_logical_axes( - optimizer.state.param_states, param_logical_axes - ) - - optimizer_logical_axes['state']['param_states'] = ( - serialization.to_state_dict(optax_state_axes) - ) - - return optimizer.restore_state(frozen_dict.unfreeze(optimizer_logical_axes)) - - def state_dict(self, target, state): - """Override state dict function. - - We need to override this function because many optax transformations use - `optax.EmptyState`, which produces empty dict in the state dict. This causes - the T5 training loop to fail in multiple places. As a remedy, we will - filter out the generated state dict so that there are no empty dict in the - output. - - The restore_state function is also overridden to reconstruct those empty - dict. - - Args: - target: Pytree of target variables. - state: Pytree of optimizer state. - - Returns: - A nested state. - """ - state_dict = to_state_dict(state) - - # This step removes any empty dict (recursively) in the state dict. - state_dict = traverse_util.unflatten_dict( - traverse_util.flatten_dict(state_dict, sep='/'), sep='/' - ) - - return to_state_dict({ - 'target': to_state_dict(target), - 'state': state_dict, - }) - - def restore_state(self, opt_target, opt_state, state_dict): - """Override to restore empty dicts corresponding to `optax.EmptyState`. - - Args: - opt_target: the optimizer target. - opt_state: the optimizer state. - state_dict: the state dict containing the desired new state of the - optimizer. - - Returns: - a tuple of the optimizer target and state with the restored values from - the state dict. - """ - opt_target = from_state_dict(opt_target, state_dict['target']) - - # Get all the possible keys in the reference optimizer state. - flat_ref_opt_state_dict = traverse_util.flatten_dict( - to_state_dict(opt_state), keep_empty_nodes=True, sep='/' - ) - - flat_src_opt_state_dict = dict( - traverse_util.flatten_dict(state_dict['state'], sep='/') - ) - # Adding the empty paths back to flat_src_opt_state_dict. - for k, v in flat_ref_opt_state_dict.items(): - if k in flat_src_opt_state_dict: - continue - # The key is not in the input state dict, presumably because it - # corresponds to an empty dict. - if v != traverse_util.empty_node: - raise ValueError( - f'Failed to restore optimizer state, path {k} is not present ' - 'in the input optimizer state dict.' - ) - flat_src_opt_state_dict[k] = v - - # Restore state from the enhanced state dict. - opt_state = from_state_dict( - opt_state, - traverse_util.unflatten_dict(flat_src_opt_state_dict, sep='/'), - ) - return opt_target, opt_state - - -# Optax wrapper and elementary wrapped optax optimizers. - - -def wrap_optax_optimizer(optax_optimizer): - """Converts optax optimizer constructor to a wrapped T5X-compatible optimizer. - - Args: - optax_optimizer: an optax optimizer creation function that returns an optax - GradientTransformation. - - Returns: - A function that takes the same arguments as the original optax creation - function but instead returns a wrapped OptimizerDef-compatible interface for - using the optimizer with T5X. - """ - - @functools.wraps(optax_optimizer) - def wrapped_optimizer(*args, **kwargs) -> OptimizerDef: - return OptaxWrapper(optax_optimizer(*args, **kwargs)) - - return wrapped_optimizer - - -def chain( - transformations: Sequence[optax.GradientTransformation], -) -> optax.GradientTransformation: - return optax.chain(*transformations) - - -chain = wrap_optax_optimizer(chain) -adabelief = wrap_optax_optimizer(optax.adabelief) -adagrad = wrap_optax_optimizer(optax.adagrad) -adam = wrap_optax_optimizer(optax.adam) -adamw = wrap_optax_optimizer(optax.adamw) -amos = wrap_optax_optimizer(amos.amos) -fromage = wrap_optax_optimizer(optax.fromage) -lars = wrap_optax_optimizer(optax.lars) -lamb = wrap_optax_optimizer(optax.lamb) -lion = wrap_optax_optimizer(optax.lion) -noisy_sgd = wrap_optax_optimizer(optax.noisy_sgd) -radam = wrap_optax_optimizer(optax.radam) -rmsprop = wrap_optax_optimizer(optax.rmsprop) -sgd = wrap_optax_optimizer(optax.sgd) -yogi = wrap_optax_optimizer(optax.yogi) -dpsgd = wrap_optax_optimizer(optax.dpsgd) - -# Excluded optimizers: -# TODO(levskaya): add shampoo, sm3 -# We use our own generalized adafactor implementations. -# adafactor = wrap_optax_optimizer(optax.adafactor) -# We may use a more complete quantized implementation of SM3 -# sm3 = wrap_optax_optimizer(optax.sm3) - -# Inlined Legacy Generalized Multioptimizer - - -class _Marker: - """Used to mark unoptimized leaves.""" - - def __init__(self): - self._indices = [] - - -def _tree_of_paths(tree): - """Converts a (frozen) nested dictionary into a (frozen) dict of paths.""" - is_frozen = isinstance(tree, flax.core.frozen_dict.FrozenDict) - flat_tree = traverse_util.flatten_dict(unfreeze(tree)) - path_tree = traverse_util.unflatten_dict( - {k: '/'.join(k) for k in flat_tree.keys()} - ) - if is_frozen: - path_tree = freeze(path_tree) - return path_tree - - -def _subtree_from_traversal(traversal, tree): - """Creates a (frozen) tree subset given a traversal.""" - is_frozen = isinstance(tree, flax.core.frozen_dict.FrozenDict) - flat_tree = {} - for path, leaf in zip( - traversal.iterate(_tree_of_paths(tree)), traversal.iterate(tree) - ): - flat_tree[path] = leaf - new_tree = traverse_util.unflatten_dict( - {tuple(k.split('/')): v for k, v in flat_tree.items()} - ) - if is_frozen: - new_tree = freeze(new_tree) - return new_tree - - -def _update_subtree_of_traversal(traversal, tree, update): - """Updates a (frozen) tree's subset given a traversal and update subtree.""" - is_frozen = isinstance(tree, flax.core.frozen_dict.FrozenDict) - flat_tree = traverse_util.flatten_dict(unfreeze(tree)) - flat_tree = {'/'.join(k): v for k, v in flat_tree.items()} - for path, leaf in zip( - traversal.iterate(_tree_of_paths(update)), traversal.iterate(update) - ): - flat_tree[path] = leaf - nested_d = traverse_util.unflatten_dict( - {tuple(k.split('/')): v for k, v in flat_tree.items()} - ) - if is_frozen: - nested_d = freeze(nested_d) - return nested_d - - -class MultiOptimizer(OptimizerDef): - """Generalized Multioptimizer. - - NB: Although this is provided for legacy support, it is still quite general - and should work fine with wrapped optax optimizers. But do note that the more - canonical way of mixing multiple optimizers inside optax uses optax.masked or - optax.multi_transform instead. - - A MultiOptimizer is subclass of :class:`OptimizerDef` and useful for applying - separate optimizer algorithms to various subsets of the model parameters. - - The example below creates two optimizers using - :class:`flax.traverse_util.ModelParamTraversal`: - one to optimize ``kernel`` parameters and to optimize ``bias`` parameters. - Note each optimizer is created with a different learning rate:: - - kernels = traverse_util.ModelParamTraversal( - lambda path, _: 'kernel' in path) - biases = traverse_util.ModelParamTraversal(lambda path, _: 'bias' in path) - kernel_opt = optimizers.adam(learning_rate=0.01) - bias_opt = optimizers.adam(learning_rate=0.1) - opt_def = MultiOptimizer((kernels, kernel_opt), (biases, bias_opt)) - optimizer = opt_def.create(model) - - In order to train only a subset of the parameters, you can simply use a single - :class:`flax.traverse_util.ModelParamTraversal` instance. - - If you want to update the learning rates of both optimizers online with - different learning rate schedules, you should update the learning rates when - applying the gradient. In the following example, the second optimizer is not - doing any optimization during the first 1000 steps:: - - hparams = optimizer.optimizer_def.hyper_params - new_optimizer = optimizer.apply_gradient( - grads, - hyper_params=[ - hparams[0].replace(learning_rate=0.2), - hparams[1].replace(learning_rate=jnp.where(step < 1000, 0., lr)), - ]) - """ - - def __init__( - self, - traversals_and_optimizers: Sequence[ - Tuple[traverse_util.Traversal, OptimizerDef] - ], - ): - """Create a new MultiOptimizer. - - See docstring of :class:`MultiOptimizer` for more details. - - Args: - traversals_and_optimizers: pairs of flax.traverse_util.Traversal and - `optimizers.OptimizerDef` instances. - """ - traversals, sub_optimizers = zip(*traversals_and_optimizers) - hyper_params = [opt.hyper_params for opt in sub_optimizers] - super().__init__(hyper_params) - self.traversals = traversals - self.sub_optimizers = sub_optimizers - - def init_state(self, params): - param_states = jax.tree.map(lambda x: _Marker(), params) - overlap = False - for idx, traversal in enumerate(self.traversals): - for match in traversal.iterate(param_states): - match._indices.append(idx) # pylint: disable=protected-access - overlap |= len(match._indices) > 1 # pylint: disable=protected-access - if overlap: - raise ValueError( - 'Multiple optimizers match the same leaves : ' - + str(jax.tree.map(lambda match: match._indices, param_states)) # pylint: disable=protected-access - ) - - param_states = jax.tree.map(lambda x: _Marker(), params) - for focus, opt_def in zip(self.traversals, self.sub_optimizers): - ps = _subtree_from_traversal(focus, params) - ss = opt_def.init_state(ps) - param_states = _update_subtree_of_traversal( - focus, param_states, ss.param_states - ) - # Update state to None when param is not optimized by any sub optimizer. - param_states = jax.tree.map( - lambda x: (None if isinstance(x, _Marker) else x), param_states - ) - return OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states) - - def apply_gradient(self, hyper_params, params, state, grads): - new_params = params - it = zip(self.traversals, self.sub_optimizers, hyper_params) - new_param_states = jax.tree.map(lambda x: _Marker(), params) - for focus, opt_def, hp in it: - ps = _subtree_from_traversal(focus, params) - gs = _subtree_from_traversal(focus, grads) - ss = _subtree_from_traversal(focus, state.param_states) - prev_ss = OptimizerState(state.step, ss) - new_ps, new_ss = opt_def.apply_gradient(hp, ps, prev_ss, gs) - new_params = _update_subtree_of_traversal(focus, new_params, new_ps) - new_param_states = _update_subtree_of_traversal( - focus, new_param_states, new_ss.param_states - ) - # Update state to None when param is not optimized by any sub optimizer. - new_param_states = jax.tree.map( - lambda x: (None if isinstance(x, _Marker) else x), new_param_states - ) - return new_params, OptimizerState(state.step + 1, new_param_states) - - def update_hyper_params(self, **hyper_param_overrides): - """Updates the hyper parameters with a set of overrides. - - This method is called from :meth:`Optimizer.apply_gradient` to create the - hyper parameters for a specific optimization step. - MultiOptimizer will apply the overrides for each sub optimizer. - - Args: - **hyper_param_overrides: the hyper parameters updates will override the - defaults specified in the `OptimizerDef`. Pass `hyper_params=...` to - replace all hyper parameters. - - Returns: - The new hyper parameters. - """ - hps = hyper_param_overrides.pop('hyper_params', self.hyper_params) - if hyper_param_overrides: - hps = [hp.replace(**hyper_param_overrides) for hp in hps] - return hps - - def set_param_axes(self, param_logical_axes): - """Derives factorization rules from model parameter logical axes.""" - for focus, opt_def in zip(self.traversals, self.sub_optimizers): - pla_subtree = _subtree_from_traversal(focus, param_logical_axes) - if hasattr(opt_def, 'set_param_axes'): - opt_def.set_param_axes(pla_subtree) - - def derive_logical_axes(self, optimizer, param_logical_axes): - """Derives optimizer logical partitioning from model logical partitions.""" - param_states = jax.tree.map( - lambda x: _Marker(), optimizer.state.param_states - ) - for focus, opt_def in zip(self.traversals, self.sub_optimizers): - if hasattr(opt_def, 'derive_logical_axes'): - ps = _subtree_from_traversal(focus, param_logical_axes) - ss = _subtree_from_traversal(focus, optimizer.state.param_states) - new_opt = opt_def.derive_logical_axes( - Optimizer(opt_def, OptimizerState(None, ss), ps), ps # pytype: disable=wrong-arg-types # jax-ndarray - ) - param_states = _update_subtree_of_traversal( - focus, param_states, new_opt.state.param_states - ) - # Update axes to None when param is not optimized by any sub optimizer. - param_states = jax.tree.map( - lambda x: (None if isinstance(x, _Marker) else x), param_states - ) - return Optimizer( - optimizer.optimizer_def, - OptimizerState(None, param_states), # pytype: disable=wrong-arg-types # jax-ndarray - param_logical_axes, - ) - - # TODO(levskaya): add traversal handling for state_dict / restore_state - # this is required to make this work w. optax optimizers... diff --git a/t5x-main/t5x/optimizers_test.py b/t5x-main/t5x/optimizers_test.py deleted file mode 100644 index e02e40d2f6dbf3200870efa31ca491acb53b114a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/optimizers_test.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.optimizers.""" - -import dataclasses -import functools -import operator - -from absl.testing import absltest -from absl.testing import parameterized -import chex -import flax -from flax.core import frozen_dict -import jax -import jax.numpy as jnp -import numpy as np -import optax -import seqio -from t5x import models -from t5x import optimizers -from t5x import partitioning -from t5x import test_utils -from t5x import trainer -from t5x import utils -from t5x.examples.t5 import network - - -def _assert_numpy_allclose(a, b, atol=None, rtol=None): - a, b = jnp.array(a), jnp.array(b) - a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a - b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b - kw = {} - if atol: - kw['atol'] = atol - if rtol: - kw['rtol'] = rtol - np.testing.assert_allclose(a, b, **kw) - - -def check_eq(xs, ys, atol=None, rtol=None): - xs_leaves, xs_tree = jax.tree_util.tree_flatten(xs) - ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys) - assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}" - assert jax.tree_util.tree_all( - jax.tree.map( - lambda x, y: np.array(x).shape == np.array(y).shape, - xs_leaves, - ys_leaves, - ) - ), "Leaves' shapes don't match." - assert jax.tree.map( - functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol), - xs_leaves, - ys_leaves, - ) - - -def flattened_state_dict(x): - s = flax.serialization.to_state_dict(x) - return flax.traverse_util.flatten_dict(s, sep='/') - - -def tree_shape(x): - return jax.tree.map(jnp.shape, x) - - -def tree_equals(x, y): - return jax.tree_util.tree_all(jax.tree.map(operator.eq, x, y)) - - -def get_fake_tokenized_dataset_no_pretokenized(*_, split='validation', **__): - return test_utils.get_fake_tokenized_dataset(split=split).map( - lambda x: {k: v for k, v in x.items() if not k.endswith('_pretokenized')} - ) - - -def get_t5_test_model( - optimizer_def, **config_overrides -) -> models.EncoderDecoderModel: - """Returns a tiny T5 1.1 model to use for testing.""" - tiny_config = network.T5Config( - vocab_size=128, - dtype='bfloat16', - emb_dim=8, - num_heads=4, - num_encoder_layers=2, - num_decoder_layers=2, - head_dim=3, - mlp_dim=16, - mlp_activations=('gelu', 'linear'), - dropout_rate=0.0, - logits_via_embedding=False, - ) - tiny_config = dataclasses.replace(tiny_config, **config_overrides) - vocabulary = test_utils.get_fake_vocab() - return models.EncoderDecoderModel( - module=network.Transformer(tiny_config), - input_vocabulary=vocabulary, - output_vocabulary=vocabulary, - optimizer_def=optimizer_def, - ) - - -def sgd_with_multi_transform(): - """Uses optax.multi_transform to train only decoder parameters.""" - - def _mask_fn(params): - mask = jax.tree_util.tree_map(lambda _: False, params) - mask = mask.copy( - {'decoder': jax.tree_util.tree_map(lambda _: True, mask['decoder'])} - ) - return mask - - return optax.multi_transform( - { - False: optax.set_to_zero(), - True: optax.sgd(1e-2, 0.0), - }, - _mask_fn, - ) - - -class BasicTest(chex.TestCase): - - @classmethod - def get_params(cls): - return frozen_dict.FrozenDict({ - 'forward': { - 'input_layer': { - 'embedding': jnp.zeros([16, 8], dtype=jnp.float32), - }, - 'output_layer': { - 'layer_norm': { - 'scale': jnp.zeros([8], dtype=jnp.float32), - }, - 'proj': { - 'bias': jnp.zeros([1], dtype=jnp.float32), - 'kernel': jnp.zeros([8, 1], dtype=jnp.float32), - }, - }, - }, - 'loss': { - 'loss_fn': { - 'loss_biases': jnp.zeros([2], dtype=jnp.float32), - }, - }, - }) - - @classmethod - def get_params_shapes(cls): - return jax.tree.map(jnp.shape, cls.get_params()) - - @classmethod - def get_param_logical_axes(cls): - return frozen_dict.FrozenDict({ - 'forward': { - 'input_layer': { - 'embedding': partitioning.PartitionSpec('vocab', 'embed'), - }, - 'output_layer': { - 'layer_norm': { - 'scale': partitioning.PartitionSpec( - 'embed', - ), - }, - 'proj': { - 'bias': partitioning.PartitionSpec( - 'output_head', - ), - 'kernel': partitioning.PartitionSpec( - 'embed', 'output_head' - ), - }, - }, - }, - 'loss': { - 'loss_fn': { - 'loss_biases': partitioning.PartitionSpec( - 'unmodeled', - ), - }, - }, - }) - - def test_logical_axes_adamw(self): - opt = optax.adamw(0.001, weight_decay=0.001) - wrapper = optimizers.OptaxWrapper(opt) - optimizer = wrapper.create(self.get_params()) - got = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes()) - want = optimizers.Optimizer( - optimizer_def=wrapper, - state=optimizers.OptimizerState( - step=None, - param_states=( - optax.ScaleByAdamState( - count=None, - mu=self.get_param_logical_axes(), - nu=self.get_param_logical_axes(), - ), - optax.EmptyState(), - optax.EmptyState(), - ), - ), - target=self.get_param_logical_axes(), - ) - chex.assert_trees_all_equal(got, want) - - @parameterized.parameters( - ('sgd', lambda: optax.sgd(1e-2, 0.0)), - ('adam', lambda: optax.adam(1e-1)), - ('adamw', lambda: optax.adamw(1e-1)), - ('lamb', lambda: optax.adamw(1e-1)), - ('lion', lambda: optax.lion(1e-2)), - ('rmsprop', lambda: optax.rmsprop(1e-1)), - ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)), - ('fromage', lambda: optax.fromage(1e-2)), - ('adabelief', lambda: optax.adabelief(1e-1)), - ('radam', lambda: optax.radam(1e-1)), - ('yogi', lambda: optax.yogi(1.0)), - ) - def test_sanity_check_logical_axes(self, opt_name, opt_fn): - opt = opt_fn() - - wrapper = optimizers.OptaxWrapper(opt) - optimizer = wrapper.create(self.get_params()) - _ = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes()) - - # TODO(rosun): basic sanity check, we just want to make sure if a param - # name, e.g., `loss_biases` appear in the tree, the corresponding value is - # always a PartitionSpec. - - def test_adamw_state_serialization(self): - opt = optax.adamw(0.001, weight_decay=0.001) - wrapper = optimizers.OptaxWrapper(opt) - optimizer = wrapper.create(self.get_params()) - - state_dict = optimizer.state_dict() - - chex.assert_trees_all_equal( - frozen_dict.FrozenDict(jax.tree.map(jnp.shape, state_dict)), - frozen_dict.FrozenDict({ - 'target': self.get_params_shapes(), - 'state': { - 'step': (), - 'param_states': { - '0': { - 'count': (), - 'mu': self.get_params_shapes(), - 'nu': self.get_params_shapes(), - }, - # NB: We eliminate empty tuple leaves from EmptyState() in - # OptaxWrapper to avoid having the rest of T5X have to - # correctly handle this detail. e.g. we omit these: - # '1': {}, - # '2': {}, - }, - }, - }), - ) - - new_optimizer = optimizer.restore_state(state_dict) - - chex.assert_trees_all_equal(optimizer, new_optimizer) - - -class OptaxWrapperTest(chex.TestCase): - - def run_train_loop(self, optimizer_def): - # Construct input data. - - ds = get_fake_tokenized_dataset_no_pretokenized(split='validation') - ds = seqio.EncDecFeatureConverter()( - ds, task_feature_lengths={'inputs': 8, 'targets': 8} - ) - ds = ds.repeat().batch(8) - ds_iter = ds.as_numpy_iterator() - first_batch = next(ds_iter) - - model = get_t5_test_model(optimizer_def, vocab_size=128) - - learning_rate_fn = utils.create_learning_rate_scheduler() - - input_shapes = jax.tree.map(jnp.shape, first_batch) - input_types = jax.tree.map(lambda x: jnp.dtype(x.dtype), first_batch) - - partitioner = partitioning.PjitPartitioner( - num_partitions=2, - logical_axis_rules=partitioning.standard_logical_axis_rules(), - ) - - train_state_initializer = utils.TrainStateInitializer( - optimizer_def=model.optimizer_def, - init_fn=model.get_initial_variables, - input_shapes=input_shapes, - input_types=input_types, - partitioner=partitioner, - ) - - train_state_axes = train_state_initializer.train_state_axes - train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0)) - - trainer_instance = trainer.Trainer( - model, - train_state=train_state, - partitioner=partitioner, - eval_names=[], - summary_dir=None, - train_state_axes=train_state_axes, - rng=jax.random.PRNGKey(0), - learning_rate_fn=learning_rate_fn, - num_microbatches=1, - ) - - chex.assert_tree_all_finite(trainer_instance.train_state.params) - for _ in range(2): - trainer_instance.train(ds_iter, 1) - chex.assert_tree_all_finite(trainer_instance.train_state.params) - - # check save/restore structural equality - restored_instance = trainer_instance.train_state.restore_state( - trainer_instance.train_state.state_dict() - ) - chex.assert_trees_all_equal_structs( - trainer_instance.train_state, restored_instance - ) - - # NOTE(levskaya): these are surprisingly slow tests on CPU. - @parameterized.parameters( - ('sgd', lambda: optax.sgd(1e-2, 0.0)), - ('adam', lambda: optax.adam(1e-1)), - ('adamw', lambda: optax.adamw(1e-1)), - ('lamb', lambda: optax.adamw(1e-1)), - ('lion', lambda: optax.lion(1e-2)), - # ('rmsprop', lambda: optax.rmsprop(1e-1)), - # ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)), - # ('fromage', lambda: optax.fromage(1e-2)), - ('adabelief', lambda: optax.adabelief(1e-1)), - # ('radam', lambda: optax.radam(1e-1)), - ('yogi', lambda: optax.yogi(1.0)), - ('multi_transform', sgd_with_multi_transform), - ) - def test_optimizer(self, opt_name, opt_fn): - opt = opt_fn() - optimizer_def = optimizers.OptaxWrapper(opt) - self.run_train_loop(optimizer_def) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/partitioning.py b/t5x-main/t5x/partitioning.py deleted file mode 100644 index 2ea7b945575547f0b2172a9fe7940e8436d766a6..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/partitioning.py +++ /dev/null @@ -1,1147 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for partitioning.""" - -import abc -import collections -import dataclasses -import functools -import typing -from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union - -from absl import logging -import cached_property -from flax import traverse_util -from flax.linen import partitioning as flax_partitioning -import jax -from jax import numpy as jnp -from jax import random -from jax.experimental import multihost_utils -from jax.experimental.mesh_utils import create_hybrid_device_mesh -from jax.experimental.pjit import pjit -from jax.interpreters import pxla -from jax.sharding import Mesh -from jax.sharding import PartitionSpec -import numpy as np -from t5x import train_state as train_state_lib - -JaxDevice = jax.Device -TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores). -OtherMesh = Tuple[int, int] -HardwareMesh = Union[TpuMesh, OtherMesh] -TrainState = train_state_lib.TrainState -LogicalAxisRules = Sequence[Tuple[str, Optional[str]]] - -if typing.TYPE_CHECKING: # See b/163639353 - cached_property = property # pylint: disable=invalid-name -else: - cached_property = cached_property.cached_property - - -class AxisNames(tuple): - """Tuple of strings specifying name for each axis. - - We create a separate class for this so JAX's pytree utilities can distinguish - it from a tuple that should be treated as a pytree, instead treating it as a - leaf. - """ - - def __new__(cls, *names): - return tuple.__new__(AxisNames, names) - - def __repr__(self): - return 'AxisNames%s' % tuple.__repr__(self) - - -def with_sharding_constraint(x, axis_resources): - """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside pjit.""" - if jax.devices()[0].platform == 'cpu' or not global_mesh_defined(): - return x - else: - return jax.lax.with_sharding_constraint(x, axis_resources) - - -# pjit Mesh creation functions. -# ----------------------------------------------------------------------------- -def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh: - """Get the bound from the given last device.""" - # Must be passed the device at the highest-coordinate corner of the - # relevant mesh, which is a requirement we know is satisfied by the last - # device in jax.devices(). - if hasattr(last_device, 'coords') and len(last_device.coords) == 3: - x, y, z = last_device.coords - return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 - else: - # On non-TPU platforms, the "mesh" is hosts x devices per host in order - # to take advantage of faster within-host interconnect. - return jax.process_count(), jax.local_device_count() - - -def get_coords(device: jax.Device) -> HardwareMesh: - """Returns the coordinates of the given device.""" - if hasattr(device, 'coords'): - return (*device.coords, device.core_on_chip) - return (device.process_index, device.id % jax.local_device_count()) - - -def global_mesh_defined(): - """Checks if global xmap/pjit mesh resource environment is defined.""" - maps_env = pxla.thread_resources.env - return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison - - -def get_mesh(model_parallel_submesh: HardwareMesh, - input_devices: Sequence[JaxDevice] = (), - input_local_devices: Sequence[JaxDevice] = (), - tile_by_host_if_needed: bool = True, - backend: Optional[str] = None) -> Mesh: - """Construct an xmap/pjit Mesh for the given model-parallel submesh. - - The resulting mesh has two resource axes: 'model', with the provided submesh - shape, and 'data', which covers the rest of the mesh. - - Args: - model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for - a single model-parallel replica's "tile" in the physical device mesh. The - first three elements (`x`, `y`, and `z`) should be factors of the pod - slice; e.g., if you are using df_4x8, then `x` should be a factor of 4 - (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z` - must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4 - (and maybe later TPUs) that allow 3D slices. `core` is the number of cores - to use from each TPU node. As communication is usually fastest inside the - same node, if you need a tile of more than 1 core, then - you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better - than (2,1,1,1). To pick a good spec, try a few possible values until you - get high TPU utilization. - input_devices: the devices to use, will use jax.devices() if this is not - set. - input_local_devices: the local devices to use, will use jax.local_devices() - if this is not set. - tile_by_host_if_needed: JAX currently requires that the parts of any sharded - array that are located on one host's local devices form a single - contiguous slice. A best effort will be made to achieve this without - "tiling" the device assignment over hosts (which can reduce XLA collective - performance). If this flag is True, then the device assignment will be - tiled over hosts if necessary to satisfy this constraint and create a - buildable mesh; if false, mesh construction will fail instead. - backend: get devices from the pinned backend, if specified. This is - useful for explicitly specifying the devices other than relying on - jax_platform_name. - - Returns: - A xmap / pjit Mesh containing the virtual device mesh with data, model axes. - """ - input_devices = input_devices or jax.devices(backend) - input_local_devices = input_local_devices or jax.local_devices(0, backend) - # Sort input_devices based on coords, as backends might not return devices - # in order. - last_device = sorted(input_devices, key=get_coords)[-1] - last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1] - logging.info('last device coords : %r\nlast local device coords: %r', - get_coords(last_device), get_coords(last_input_local_devices)) - global_hardware_mesh = bounds_from_last_device(last_device) - mesh_ndim = len(global_hardware_mesh) - local_hardware_mesh = bounds_from_last_device(last_input_local_devices) - mesh_err = ( - f'each dimension of the model parallel submesh {model_parallel_submesh} ' - 'must be a factor of the corresponding dimension of the global device ' - f'mesh {global_hardware_mesh}') - assert not any( - g % m - for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err - assert not any( - g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh)) - devices = np.empty(global_hardware_mesh, dtype=object) - for device in input_devices: - device_coords = get_coords(device) - devices[device_coords] = device - tile_by_host = tile_by_host_if_needed - if len(global_hardware_mesh) == 4: - # enable contiguous local chunks without host tiling by making Z major - global_hardware_mesh = typing.cast(Tuple[int, int, int, int], - global_hardware_mesh) - model_parallel_submesh = typing.cast(Tuple[int, int, int, int], - model_parallel_submesh) - gx, gy, gz, gc = global_hardware_mesh - mx, my, mz, mc = model_parallel_submesh - if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and - mz == gz > 1): - logging.info('ensuring YZ plane has a Z-major device order') - # YZ should be ZY - assert mc == gc, (mc, gc) - global_hardware_mesh = gx, gz, gy, gc - model_parallel_submesh = mx, mz, my, mc - devices = devices.swapaxes(1, 2) - tile_by_host = False - if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and - mz == gz > 1): - logging.info('ensuring XZ plane has a Z-major device order') - # XZ should be ZX - assert mc == gc, (mc, gc) - global_hardware_mesh = gz, gy, gx, gc - model_parallel_submesh = mz, my, mx, mc - devices = devices.swapaxes(0, 2) - tile_by_host = False - if tile_by_host: - logging.warning( - 'Tiling device assignment mesh by hosts, which may lead to ' - 'reduced XLA collective performance. To avoid this, modify ' - 'the model parallel submesh or run with more tasks per host.') - tile_err = ( - 'to tile the mesh by hosts, each dimension of the model parallel ' - 'submesh must be either a factor or a multiple of the corresponding ' - 'dimension of the per-host submesh') - - def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]: - """Split a global mesh dimension into four tiling components. - - Args: - g: global mesh bounds dimension size - m: model-parallel submesh bounds dimension size - l: local submesh bounds dimension size - - Returns: - The resulting tuple divides the dimension into the hosts component of - the data-parallel submesh, the devices component of the data-parallel - submesh, the hosts component of the model-parallel submesh, and the - devices component of the model-parallel submesh. - """ - d = g // m - if m >= l: - assert not m % l, tile_err - return (d, 1, m // l, l) - else: - assert not l % m, tile_err - return (d // (l // m), l // m, 1, m) - - # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...] - dh_dd_mh_md_tups = map(dh_dd_mh_md, global_hardware_mesh, - model_parallel_submesh, local_hardware_mesh) - # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...) - devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension - # TODO(jekbradbury): reorder local subgroups for ring locality - # Transpose to [data_host], [data_device], [model_host], [model_device] - # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...) - devices = devices.transpose(*(4 * i for i in range(mesh_ndim)), - *(4 * i + 1 for i in range(mesh_ndim)), - *(4 * i + 2 for i in range(mesh_ndim)), - *(4 * i + 3 for i in range(mesh_ndim))) - else: - # e.g. [(x_data, x_model), (y_data, y_model), ...] - model_data_tups = [ - (g // m, m) - for g, m in zip(global_hardware_mesh, model_parallel_submesh) - ] - # reshape to e.g. (x_data, x_model, y_data, y_model...) - devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension - # TODO(jekbradbury): reorder small subgroups for ring locality - # transpose to e.g. (x_data, y_data, ..., x_model, ...) - devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), - *(2 * i + 1 for i in range(mesh_ndim))) - # reshape to (data, model) - devices = devices.reshape(-1, np.prod(model_parallel_submesh)) - global_mesh = Mesh(devices, ['data', 'model']) - logging.info('global_mesh axis_names: %s', global_mesh.axis_names) - logging.info('global_mesh devices: %s', global_mesh.devices) - logging.info('global_mesh devices shape: %s', global_mesh.devices.shape) - return global_mesh - - -def get_cpu_mesh() -> Mesh: - """Trivial mesh for CPU Testing.""" - devices = np.empty( - (jax.process_count(), jax.local_device_count()), dtype=object - ) - for device in jax.devices(): - devices[device.process_index, device.id % jax.local_device_count()] = device - return Mesh(devices, ['data', 'model']) - - -def get_gpu_mesh(num_partitions: int) -> Mesh: - """Mesh for GPUs that preferentially places 'model' on NVLink.""" - nvlink_size = jax.local_device_count() - dcn_size = jax.process_count() - nvlink_mp = min(num_partitions, nvlink_size) - nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp) - dcn_mp, extra2 = divmod(num_partitions, nvlink_mp) - assert not (extra1 or extra2), ('number of partitions on GPU must be a factor' - ' or multiple of the number of local devices') - dcn_dp = dcn_size // dcn_mp - - devices = create_hybrid_device_mesh( - mesh_shape=[nvlink_dp, nvlink_mp], - dcn_mesh_shape=[dcn_dp, dcn_mp], - process_is_granule=True) - - global_mesh = Mesh(devices, ['data', 'model']) - logging.info('global_mesh axis_names: %s', global_mesh.axis_names) - logging.info('global_mesh devices: %s', global_mesh.devices) - return global_mesh - - -def default_mesh( - num_partitions: int, - model_parallel_submesh: Optional[HardwareMesh] = None, - backend: Optional[str] = None, - ici_mesh_shape: Optional[HardwareMesh] = None, - dcn_mesh_shape: Optional[HardwareMesh] = None, -) -> Mesh: - """Attempt to return a default mesh for simple cases. - - Args: - num_partitions: number of partitions to use, will be ignored if - model_parallel_submesh is provided. - model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as - the model-parallel device tile. - backend: get devices from the pinned backend, if specified. This is useful - for explicitly specifying the devices other than relying on - jax_platform_name. - ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in each - slice. The meaning of each mesh axis is defined by mesh_axis_names, so - these two params must be the same length. If dcn_mesh_shape is present, - the overall mesh is the product of ici_mesh_shape and dcn_mesh_shape. For - example, an ici_mesh_shape of [2, 3, 4] with mesh_axis_names ['replica', - 'data', 'model'] indicates 2-way replica parallelism, 3-way data - parallelism, and 4-way model parallelism over 24 devices. None, the - default, is equivalent to a sequence of ones and means that the model is - placed on a single device. - dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over - multiple slices. The overall mesh is the product of ici_mesh_shape and - dcn_mesh_shape, and the meaning of each mesh axis is defined by - mesh_axis_names, so these three params must be the same length. - - Returns: - xmap/pjit 2D Mesh with 'data', 'model' mesh axes if single-slice, otherwise - 3D Mesh with 'replica', 'data', and 'model' mesh axes. - """ - devices = jax.devices(backend) - last_device = devices[-1] - platform = last_device.platform - device_kind = last_device.device_kind - bounds = bounds_from_last_device(last_device) - - if ici_mesh_shape is not None and dcn_mesh_shape is not None: - device_mesh = create_hybrid_device_mesh( - ici_mesh_shape, - dcn_mesh_shape, - devices=devices, - ) - multi_slice_global_mesh = Mesh(device_mesh, ['replica', 'data', 'model']) - logging.info( - 'multi_slice_global_mesh axis_names: %s', - multi_slice_global_mesh.axis_names, - ) - logging.info( - 'multi_slice_global_mesh devices: %s', multi_slice_global_mesh.devices - ) - logging.info( - 'multi_slice_global_mesh devices shape: %s', - multi_slice_global_mesh.devices.shape, - ) - return multi_slice_global_mesh - - if model_parallel_submesh: - return get_mesh(model_parallel_submesh, backend=backend) - - if platform == 'cpu': - return get_cpu_mesh() - elif platform == 'gpu': - return get_gpu_mesh(num_partitions) - - mps = None - if device_kind in ('TPU v2', 'TPU v3'): - if num_partitions == 1: - mps = (1, 1, 1, 1) - elif num_partitions == 2: - mps = (1, 1, 1, 2) - elif num_partitions == 4: - mps = (2, 1, 1, 2) - elif num_partitions == 8: - mps = (2, 2, 1, 2) - elif num_partitions == 16: - mps = (4, 2, 1, 2) - # assume the use of megacore on TPU v4 - elif (device_kind == 'TPU v4' or - device_kind == 'TPU v4 lite') and bounds[3] == 1: - if num_partitions == 1: - mps = (1, 1, 1, 1) - elif num_partitions == 2: - mps = (1, 2, 1, 1) - elif num_partitions == 4: - if bounds[0] >= 4: - mps = (4, 1, 1, 1) - else: - mps = (2, 2, 1, 1) - elif num_partitions == 8: - if bounds[2] >= 8: - mps = (1, 1, 8, 1) - else: - mps = (4, 2, 1, 1) - elif num_partitions == 16: - if bounds[2] >= 16: - mps = (1, 1, 16, 1) - elif bounds[0] >= 8: - mps = (8, 2, 1, 1) - elif bounds[0] >= 4: - mps = (4, 4, 1, 1) - else: - mps = (2, 2, 4, 1) - - if mps is None: - raise ValueError( - 'No default mesh for this configuration: specify ' - 'config.model_parallel_submesh explicitly. \n' - f'Platform: {platform}\n' - f'Device kind: {device_kind}\n' - f'Num partitions: {num_partitions}\n' - f'Bounds: {bounds}' - ) - return get_mesh(mps, backend=backend) - - -# Data chunking helper. -# ----------------------------------------------------------------------------- -@dataclasses.dataclass -class LocalChunkInfo: - # The logical slice of an array located on this host's local devices. - slice: Tuple[slice, ...] - # A unique index for this host/local chunk among chunks with the same slice. - replica_id: int - - -class LocalChunker: - """Utility class to aid chunking of sharded arrays in multihost settings.""" - - def __init__(self, global_mesh: Mesh): - self.global_mesh = global_mesh - local_mesh = global_mesh.local_mesh - first_local_device = local_mesh.devices.reshape(-1)[0] - host_location = collections.OrderedDict( - zip( - global_mesh.shape.keys(), - list(zip(*np.nonzero( - global_mesh.devices == first_local_device)))[0])) - self.num_chunks = collections.OrderedDict() - self.chunk_ids = collections.OrderedDict() - self.mesh_axes = list(global_mesh.shape.keys()) - for mesh_axis in self.mesh_axes: - num_devices_per_chunk = local_mesh.shape[mesh_axis] - self.num_chunks[mesh_axis] = ( - global_mesh.shape[mesh_axis] // num_devices_per_chunk) - self.chunk_ids[mesh_axis] = ( - host_location[mesh_axis] // num_devices_per_chunk) - - def get_local_chunk_info( - self, global_shape: Tuple[int, ...], - mesh_axes: Sequence[Optional[str]]) -> LocalChunkInfo: - """Get the local chunk info for a given array shape and sharded axes. - - Args: - global_shape: the global, unsharded shape of the array to chunk. - mesh_axes: a sequence of names (or None) of equal rank to `global_shape` - that specifies which mesh dimensions the array is sharded along. - - Returns: - LocalChunkInfo containing the logical slices of the array found on this - host's local devices, as well as the replica index for this chunk among - chunks with the same slice. The latter is used to determine which - host should write this chunk during checkpointing. - """ - local_slice = [slice(None) for dim in global_shape] - sharded_mesh_axes = set() - for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)): - if not mesh_axis: - continue - sharded_mesh_axes.add(mesh_axis) - if not isinstance(mesh_axis, str): - raise NotImplementedError('TODO(jekbradbury)') - chunk_id = self.chunk_ids[mesh_axis] - chunk_size = size // self.num_chunks[mesh_axis] - local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size) - - replica_id = self.get_replica_id(sharded_mesh_axes) - - return LocalChunkInfo(tuple(local_slice), replica_id) - - def get_shard_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int: - """Given mesh axes used for sharding, computes current host's shard id. - - To give an example, let's say there are two axes globally: replica, data, - and model, the mesh axes for sharding is ('replica', 'data'), which means we - are going to partition an array along 'replica' and 'data' axes. - The shard_id is to show the index of the current local host along the - sharding axes (in this example, it's 'replica' and 'data' axes). - - More concretely, let's say we have 4 local hosts, and we use 'replica' and - 'data' axes for data parallel (2 hosts along the replica axis, and 2 host - along the data axis). The host located in ('replica': 0, 'data': 0), we - should assign data shard-0 to it. For host ('replica': 0, 'data': 1), we - assign shard-1. For host ('replica': 1, 'data': 0), we assign shard-2. - For host ('replica': 1, 'data': 1), we assign shard-3. - - Note: the host location along 'replica' and 'data' axes, e.g., - ('replica': 0, 'data': 0) is named chunk_id and stored in - self._local_chunker.chunk_ids[axis]. - - Args: - sharded_mesh_axes: the mesh axes for sharding. - - Returns: - the index of the current local host along the sharding axes. - """ - if isinstance(sharded_mesh_axes, str): - sharded_mesh_axes = (sharded_mesh_axes,) - - shard_id = 0 - for mesh_axis in sharded_mesh_axes: - chunk_id = self.chunk_ids[mesh_axis] - shard_id = shard_id * self.num_chunks[mesh_axis] + chunk_id - - return shard_id - - def get_replica_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int: - """Given mesh axes used for sharding, computes current host's replica id. - - To give an example, let's say there are two axes globally: data, and model, - the mesh axes for sharding is ('data', ), which means we are going to - partition an array along 'data' axis and replicate it along 'model' axis. - The replica_id is to show the index of the current local host along the - 'model' axis. - - Args: - sharded_mesh_axes: the mesh axes for sharding. - - Returns: - the index of the current local host along the non-sharding axes (i.e., - replicating axes). - """ - if isinstance(sharded_mesh_axes, str): - sharded_mesh_axes = (sharded_mesh_axes,) - - replicated_mesh_axes = [ - mesh_axis for mesh_axis in self.mesh_axes - if mesh_axis not in sharded_mesh_axes - ] - replica_id = 0 - for mesh_axis in replicated_mesh_axes: - chunk_id = self.chunk_ids[mesh_axis] - replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id - - return replica_id - - -def standard_logical_axis_rules( - activation_partitioning_dims: int = 1, - parameter_partitioning_dims: int = 1, - additional_rules: Optional[LogicalAxisRules] = None) -> LogicalAxisRules: - """Default sharding rules for T5X model in terms of logical axis names. - - Args: - activation_partitioning_dims: enables 2-D activation sharding when set to 2. - parameter_partitioning_dims: enables 2-D parameter sharding when set to 2. - additional_rules: additional rules (a sequence of tuples) that will be - appended to the standard rules. - - Returns: - Sequence of logical axis rules - """ - logging.info( - '`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d', - activation_partitioning_dims, parameter_partitioning_dims) - - if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1: - rules = [ - ('batch', 'data'), - ('vocab', 'model'), - ('embed', None), - ('mlp', 'model'), - ('heads', 'model'), - ('kv', None), - ('joined_kv', 'model'), # joined heads+kv dim in 2D attn param layouts - ] - elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1: - rules = [ - ('batch', 'data'), - ('vocab', 'model'), - ('mlp', 'model'), - ('heads', 'model'), - ('kv', None), - ('joined_kv', 'model'), - ('embed', 'model'), - ] - elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2: - rules = [ - ('batch', 'data'), - ('vocab', 'model'), - ('mlp', 'model'), - ('heads', 'model'), - ('kv', None), - ('joined_kv', 'model'), - ('embed', 'data'), - ] - elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2: - rules = [ - ('batch', 'data'), - ('vocab', 'model'), - ('mlp', 'model'), - ('heads', 'model'), - ('kv', None), - ('joined_kv', 'model'), - ('embed', 'model'), - ('embed', 'data'), - ] - else: - raise ValueError( - f'`activation_partitioning_dims` = {activation_partitioning_dims} ' - f'`parameter_partitioning_dims` = {parameter_partitioning_dims} ' - 'is not supported.') - - # Add the common rules for the replicated logical axes names. - replicated_rules = [ - ('relpos_buckets', None), - ('abspos_buckets', None), - ('length', None), - ('layers', None), - ('stack', None), - ('mlp_activations', None), - ] - rules.extend(replicated_rules) - - if additional_rules: - rules.extend(additional_rules) - - return rules - - -# NB: This needs to be top-level for the jax compilation cache. -def _id_fn(x, ix): - """Identity function for copying parameters to the devices, sharded.""" - # A pure identity such as `lambda x, *: x` can get optimized away, so we - # include a random.split as a cheap function that cannot be optimized away. - y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32))) - return x, y - - -@dataclasses.dataclass -class DataLayout: - """Represents data layout for the partitioned model.""" - batch_size: int - shard_id: int - num_shards: int - is_first_host_in_replica_set: bool - - -PartitionedCallable = Callable[..., Any] -CompiledPartitionedCallable = Callable[..., Any] - - -class BasePartitioner(metaclass=abc.ABCMeta): - """Interface for partitioning computations across hardware devices.""" - - def __init__( - self, - num_partitions: Optional[int] = None, - model_parallel_submesh: Optional[HardwareMesh] = None, - params_on_devices: bool = True, - backend: Optional[str] = None, - ici_mesh_shape: Optional[HardwareMesh] = None, - dcn_mesh_shape: Optional[HardwareMesh] = None, - ): - """Configures the partitioner. - - Args: - num_partitions: the number of partitions to use. Ignored if - `model_parallel_submesh` is provided. - model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use - as the model-parallel device tile. This submesh is used for the larger - of the two parameter dimensions, and, if 2-D activation sharding is - enabled, for the model dimension of activations. The rest of the mesh is - used for data parallelism and, if 2-D parameter sharding is enabled, the - other parameter dimension. - params_on_devices: whether to keep the params on devices, if False - - params stay in the host memory. Note that some partitioners might ignore - this setting, for example if they don't support storing all params on - device memory. - backend: get devices from the pinned backend, if specified. This is useful - for explicitly specifying the devices other than relying on - jax_platform_name. - ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in - each slice. The meaning of each mesh axis is defined by mesh_axis_names, - so these two params must be the same length. If dcn_mesh_shape is - present, the overall mesh is the product of ici_mesh_shape and - dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with - mesh_axis_names ['replica', 'data', 'mdl'] indicates 2-way replica - parallelism, 3-way data parallelism, and 4-way model parallelism over 24 - devices. None, the default, is equivalent to a sequence of ones and - means that the model is placed on a single device. - dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over - multiple slices. The overall mesh is the product of ici_mesh_shape and - dcn_mesh_shape, and the meaning of each mesh axis is defined by - mesh_axis_names, so these three params must be the same length. - """ - - if not num_partitions and not model_parallel_submesh: - raise ValueError('At least one of `num_partitions` or ' - '`model_parallel_submesh` must be set.') - - if model_parallel_submesh is not None and len(model_parallel_submesh) != 4: - logging.error( - ( - '`model_parallel_submesh` must be either None or a 4-tuple. Got' - ' `model_parallel_submesh`=%r. A ValueError will be raised' - ' beginning March 1, 2022.' - ), - model_parallel_submesh, - ) - - if bool(num_partitions) and bool(model_parallel_submesh): - logging.error( - 'At most one of `num_partitions` or `model_parallel_submesh` can be ' - 'set. Got `num_partitions=%r` and `model_parallel_submesh`=%r. A ' - 'ValueError will be raised beginning March 21, 2022.', - num_partitions, - model_parallel_submesh, - ) - - self._num_partitions = num_partitions - self._model_parallel_submesh = model_parallel_submesh - self._params_on_devices = params_on_devices - if ici_mesh_shape is None or dcn_mesh_shape is None: - self._data_axis = 'data' - else: - self._data_axis = ('replica', 'data') - self._backend = backend - self._ici_mesh_shape = ici_mesh_shape - self._dcn_mesh_shape = dcn_mesh_shape - - @property - def mesh(self) -> Mesh: - raise NotImplementedError - - @property - def data_partition_spec(self) -> PartitionSpec: - return PartitionSpec(self._data_axis) - - @property - def data_mesh_size(self) -> int: - """Data mesh size. - - Data mesh size is defined as the number of global devices involved to - carry out data parallel. Let's say we have a global mesh: ('replica': 2, - 'data': 4, 'model': 2), and axes 'replica' and 'data' are responsible for - the data parallel, that means we have 2*4 = 8 devices involved - i.e., data - mesh size is 8. - - Returns: - the id of the shard for the axes being replicated among the devices used - to shard the sharded_mesh_axes. - """ - data_submesh_sizes = ( - [self.mesh.shape[self._data_axis]] - if isinstance(self._data_axis, str) - else [self.mesh.shape[axis] for axis in self._data_axis] - ) - data_mesh_size = functools.reduce(lambda x, y: x * y, data_submesh_sizes) - return data_mesh_size - - @property - def data_shards(self) -> int: - """Number of data shards. - - Let's say we are dealing with 2 slices of df4x2 TPUs. In data pipeline - we need prepare / send one data shard to each local host. This means, we - need 4 shards since we have 4 local hosts. How to infer the number of hosts - from mesh information? In this case, we have a global mesh: ('replica': 2, - 'data': 8, 'model': 2). Each local host (i.e., df2x2) has this local mesh: - ('replica': 1, 'data': 4, 'model': 2). By dividing global mesh with local - mesh, we can get the count of hosts. - - Returns: - Number of data shards. Each shard will be sent to one local host. - """ - data_chunks = ( - [self._local_chunker.num_chunks[self._data_axis]] - if isinstance(self._data_axis, str) - else [self._local_chunker.num_chunks[axis] for axis in self._data_axis] - ) - data_shards = functools.reduce(lambda x, y: x * y, data_chunks) - return data_shards - - @property - def data_shard_id(self) -> int: - """Data shard id for the current host. - - Returns: - Index of data shard that will be sent to the current local host. - """ - return self._local_chunker.get_shard_id(self._data_axis) - - def get_data_layout( - self, batch_size: Optional[int] = None, host_index: Optional[int] = None - ) -> DataLayout: - """Returns filled `DataLayout` based on the partitioned model layout. - - Args: - batch_size: if set, indicates the requested batch size. The exception will - be raised if this batch size is not compatible with the layout. If not - set, the batch size is inferred from the layout. - host_index: indicates the host index to use for the calculations, if not - set - use JAX-provided one. Should be in [0, num_hosts) interval and the - order should match the order of corresponding CPU devices in - `jax.devices()`. - - Returns: - Filled `DataLayout` structure. - """ - if host_index is not None: - raise NotImplementedError('Explicit host_index is not yet implemented.') - if self._data_axis is None: - return DataLayout( - batch_size=batch_size, - shard_id=0, - num_shards=1, - is_first_host_in_replica_set=(jax.process_index() == 0)) - - batch_size = batch_size or self.data_mesh_size - if batch_size % self.data_mesh_size: - raise ValueError( - f'Batch size ({batch_size}) must be divisible by corresponding ' - f'data mesh size ({self.data_mesh_size}).' - ) - - if batch_size % self.data_shards: - raise ValueError( - f'Batch size ({batch_size}) must be divisible by number of ' - f'data shards ({self.data_shards}).' - ) - replica_id = self._local_chunker.get_replica_id(self._data_axis) - return DataLayout( - batch_size=int(batch_size), - shard_id=int(self.data_shard_id), - num_shards=int(self.data_shards), - is_first_host_in_replica_set=(replica_id == 0), - ) - - def get_local_chunk_info( - self, global_shape: Tuple[int, ...], - mesh_axes: Sequence[Optional[str]]) -> LocalChunkInfo: - """Returns the local chunk info for a given array shape and sharded axes.""" - return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes) - - @property - def params_on_devices(self): - return self._params_on_devices - - @params_on_devices.setter - def params_on_devices(self, value): - self._params_on_devices = value - - def move_params_to_devices(self, train_state: TrainState, - train_state_axes: TrainState) -> TrainState: - """Moves the optimizer parameters to devices.""" - p_id_fn = self.partition( - _id_fn, - in_axis_resources=(train_state_axes, None), - out_axis_resources=(train_state_axes, None), - donate_argnums=(0,)) - if jax.process_count() > 1: - train_state = host_local_array_to_global_array( - train_state, self.mesh, train_state_axes - ) - train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32)) - return train_state - - @property - @abc.abstractmethod - def _local_chunker(self): - """Returns the chunker that matches the parameters of this partitioner.""" - raise NotImplementedError - - def get_logical_axes(self, train_state: TrainState) -> TrainState: - """Returns a copy of TrainState with Optional[AxisNames] as leaves.""" - # By default, return None for the logical axes. - return train_state.restore_state( - jax.tree.map(lambda x: None, train_state.state_dict()) - ) - - def get_mesh_axes(self, train_state: TrainState) -> TrainState: - """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.""" - raise NotImplementedError - - @abc.abstractmethod - def partition( - self, - fn: Callable, # pylint: disable=g-bare-generic - in_axis_resources, - out_axis_resources, - static_argnums: Union[int, Sequence[int]] = (), - donate_argnums: Union[int, Sequence[int]] = () - ) -> PartitionedCallable: - """Partitions the computation using partitioner-specific implementation. - - Args: - fn: the function to partition. - in_axis_resources: Pytree of structure matching that of arguments to `fn`, - with all actual arguments replaced by resource assignment - specifications. It is also valid to specify a pytree prefix (e.g. one - value in place of a whole subtree), in which case the leaves get - broadcast to all values in that subtree. - The valid resource assignment specifications are: - `None`: in which case the value will be replicated on all devices - `PartitionSpec`: a tuple of length at most equal to the rank of the - partitioned value. Each element can be a `None`, a mesh axis or a - tuple of mesh axes, and specifies the set of resources assigned to - partition the value's dimension matching its position in the spec. - out_axis_resources: Like `in_axis_resources`, but specifies resource - assignment for function outputs. - static_argnums: an optional int or collection of ints that specify which - positional arguments to treat as static (compile-time constant) in the - partitioned function. - donate_argnums: an optional int or collection of ints that specify which - argument buffers are "donated" to the computation. It is safe to donate - argument buffers if you no longer need them once the computation has - finished. - - Returns: - A partitioned version of the input function. - """ - raise NotImplementedError - - @abc.abstractmethod - def compile(self, partitioned_fn: PartitionedCallable, - *args) -> CompiledPartitionedCallable: - """Compiles and returns the partitioned function, or the original. - - Args: - partitioned_fn: The partitioned function. - *args: Sample arguments to the partitioned function matching the input - shapes that will be passed to the compiled function. - - Returns: - The compiled function, or the original if this partitioner does not - support compilation. - """ - raise NotImplementedError - - -class PjittedFnWithContext(PartitionedCallable): - """Wraps pjitted function to apply the appropriate contexts.""" - - def __init__(self, - pjitted_fn, - partition_mesh: Mesh, - logical_axis_rules: flax_partitioning.LogicalRules = ()): - self._pjitted_fn = pjitted_fn - self._mesh = partition_mesh - self._logical_axis_rules = logical_axis_rules - - def __call__(self, *args, **kwargs): - with Mesh(self._mesh.devices, - self._mesh.axis_names), flax_partitioning.axis_rules( - self._logical_axis_rules): - return self._pjitted_fn(*args, **kwargs) - - def lower(self, *args, **kwargs): - with Mesh(self._mesh.devices, - self._mesh.axis_names), flax_partitioning.axis_rules( - self._logical_axis_rules): - return self._pjitted_fn.lower(*args, **kwargs) - - -class BasePjitPartitioner(BasePartitioner): - """Partitioner that uses T5X version of jax.pjit.""" - - @cached_property - def _local_chunker(self) -> LocalChunker: - return LocalChunker(self.mesh) - - @cached_property - def mesh(self) -> Mesh: - return default_mesh( - self._num_partitions, - self._model_parallel_submesh, - self._backend, - self._ici_mesh_shape, - self._dcn_mesh_shape, - ) - - def partition( - self, - fn: Callable, # pylint: disable=g-bare-generic - in_axis_resources, - out_axis_resources, - static_argnums: Union[int, Sequence[int]] = (), - donate_argnums: Union[int, Sequence[int]] = (), - ) -> PjittedFnWithContext: - pjitted = pjit( - fn, - in_shardings=in_axis_resources, - out_shardings=out_axis_resources, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - ) - - return PjittedFnWithContext(pjitted, self.mesh) - - def compile(self, partitioned_fn: PjittedFnWithContext, - *args) -> CompiledPartitionedCallable: - return partitioned_fn.lower(*args).compile() - - -class PjitPartitioner(BasePjitPartitioner): - """Partitioner that uses named axes and jax.pjit.""" - - def __init__( - self, - num_partitions: Optional[int] = None, - model_parallel_submesh: Optional[HardwareMesh] = None, - params_on_devices: bool = True, - backend: Optional[str] = None, - ici_mesh_shape: Optional[HardwareMesh] = None, - dcn_mesh_shape: Optional[HardwareMesh] = None, - logical_axis_rules: Optional[LogicalAxisRules] = None, - ): - """PjitPartitioner constructor. - - See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details. - - Args: - num_partitions: an integer that specifies the size of the model parallel - submesh to be automatically selected for the current topology. See - `model_parallel_submesh` for details on how this submesh is used. - Mutually exclusive with `model_parallel_submesh`. - model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)` - submesh model-parallel device tile, an axis of accelerator parallelism - orthogonal to data parallelism. Array axes in a model's parameters or - activations can be sharded over this submesh using axis rules (see - `logical_axis_rules`) that map them to 'model'. The effective number of - model sub-partitions is equal to `np.prod(model_parallel_submesh)` and - must evenly divide the total number of devices (i.e., - `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest - of the TPU mesh is the data parallel submesh, providing - `jax.device_count() // np.prod(model_parallel_submesh)` partitions. It - is used for data (batch) parallelism and to shard other array axes that - are mapped to 'data'. This argument is mutually exclusive with - `num_partitions`. - params_on_devices: whether to keep the params on devices, if False - - params stay in the host memory. Note that some partitioners might ignore - this setting, for example if they don't support storing all params on - device memory. - backend: get devices from the pinned backend, if specified. This is useful - for explicitly specifying the devices other than relying on - jax_platform_name. - ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in - each slice. The meaning of each mesh axis is defined by mesh_axis_names, - so these two params must be the same length. If dcn_mesh_shape is - present, the overall mesh is the product of ici_mesh_shape and - dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with - mesh_axis_names ['replica', 'data', 'model'] indicates 2-way replica - parallelism, 3-way data parallelism, and 4-way model parallelism over 24 - devices. None, the default, is equivalent to a sequence of ones and - means that the model is placed on a single device. - dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over - multiple slices. The overall mesh is the product of ici_mesh_shape and - dcn_mesh_shape, and the meaning of each mesh axis is defined by - mesh_axis_names, so these three params must be the same length. - logical_axis_rules: a priority-ordered sequence of KV tuples that maps - logical axis names to either `None` (not sharded), 'model' (to shard - across the model-parallel submesh), or 'data' (to shard across the - data-parallel submesh). - """ - super().__init__( - num_partitions=num_partitions, - model_parallel_submesh=model_parallel_submesh, - params_on_devices=params_on_devices, - backend=backend, - ici_mesh_shape=ici_mesh_shape, - dcn_mesh_shape=dcn_mesh_shape, - ) - if logical_axis_rules is None: - logical_axis_rules = standard_logical_axis_rules() - if ici_mesh_shape is not None and dcn_mesh_shape is not None: - # Split batch over new replica axis. - logical_axis_rules = ( - (k, ('replica', 'data') if k == 'batch' else v) - for k, v in logical_axis_rules - ) - self._logical_axis_rules = tuple(logical_axis_rules) - (self._data_axis,) = flax_partitioning.logical_to_mesh_axes( - ['batch'], self._logical_axis_rules - ) - - def partition( - self, - fn: Callable, # pylint: disable=g-bare-generic - in_axis_resources, - out_axis_resources, - static_argnums: Union[int, Sequence[int]] = (), - donate_argnums: Union[int, Sequence[int]] = () - ) -> PjittedFnWithContext: - """Partitions the function using jax.pjit.""" - pjitted = pjit( - fn, - in_shardings=in_axis_resources, - out_shardings=out_axis_resources, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - ) - - return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules) - - @property - def logical_axis_rules(self): - """Returns the logical axis rules.""" - return self._logical_axis_rules - - def get_logical_axes(self, train_state: TrainState) -> TrainState: - """Returns a copy of TrainState with Optional[AxisNames] as leaves.""" - return train_state.as_logical_axes() - - def get_mesh_axes(self, train_state: TrainState) -> TrainState: - """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.""" - logical_axes = self.get_logical_axes(train_state) - - def _logical_to_mesh_axes(param_name, logical_axes): - if logical_axes is None: - return None - elif logical_axes is traverse_util.empty_node: - return traverse_util.empty_node - try: - return flax_partitioning.logical_to_mesh_axes(logical_axes, - self._logical_axis_rules) - except ValueError as e: - raise ValueError(f'Failed to map logical axes for {param_name}') from e - - flat_logical_axes = traverse_util.flatten_dict( - logical_axes.state_dict(), keep_empty_nodes=True, sep='/') - flat_mesh_axes = { - k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items() - } - - return logical_axes.restore_state( - traverse_util.unflatten_dict(flat_mesh_axes, sep='/')) - - -# arr_tree is a PyTree of jax.Array or np.ndarray and -# pspecs is PyTree[jax.sharding.PartitionSpec] -def host_local_array_to_global_array(arr_tree, mesh: jax.sharding.Mesh, pspecs): - pspecs = jax.tree.map( - lambda x: PartitionSpec() if x is None else x, - pspecs, - is_leaf=lambda x: x is None, - ) - return multihost_utils.host_local_array_to_global_array( - arr_tree, mesh, pspecs - ) diff --git a/t5x-main/t5x/partitioning_test.py b/t5x-main/t5x/partitioning_test.py deleted file mode 100644 index 199c27db624651e45bd8cbaafba41271652a0170..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/partitioning_test.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.partitioning.""" - -import collections -import unittest - -from absl.testing import absltest -from absl.testing import parameterized -import flax.core -from flax.linen import partitioning as nn_partitioning -import jax -import numpy as np -from t5x import adafactor -from t5x import optimizers -from t5x import partitioning -from t5x import test_utils as ptu -from t5x import train_state - -jax.config.parse_flags_with_absl() - -mock = absltest.mock -TpuDevice = ptu.TpuDevice -TPUV3_32 = ptu.make_devices(4, 4, 1, 2, kind='TPU v3') -AxisMetadata = nn_partitioning.AxisMetadata -PartitionSpec = partitioning.PartitionSpec - - -class PartitioningTest(absltest.TestCase): - - @mock.patch('jax.process_count') - @mock.patch('jax.local_device_count') - def test_bounds_from_last_device(self, local_device_count, process_count): - last_device = mock.Mock(coords=(3, 3, 3), core_on_chip=1) - tpu_bounds = partitioning.bounds_from_last_device(last_device) - self.assertEqual(tpu_bounds, (4, 4, 4, 2)) - - last_device = mock.Mock(spec=[]) - process_count.return_value = 1 - local_device_count.return_value = 4 - non_tpu_bounds = partitioning.bounds_from_last_device(last_device) - self.assertEqual(non_tpu_bounds, (1, 4)) - - @mock.patch('jax.local_device_count') - def test_get_coords(self, local_device_count): - device = mock.Mock(coords=(1, 0, 1), core_on_chip=1) - coords = partitioning.get_coords(device) - self.assertEqual(coords, (1, 0, 1, 1)) - - device = mock.Mock(spec=['process_index', 'id']) - device.process_index = 1 - device.id = 9 - local_device_count.return_value = 8 - coords = partitioning.get_coords(device) - self.assertEqual(coords, (1, 1)) - - @unittest.skipIf(jax.__version_info__ < (0, 4, 5), 'Test requires jax 0.4.5') - @mock.patch('jax.local_devices') - @mock.patch('jax.devices') - @mock.patch(f'{jax.process_index.__module__}.process_index') - def test_default_mesh(self, process_index_fn, devices_fn, local_devices_fn): - devices_fn.return_value = TPUV3_32 - local_devices_fn.return_value = [ - d for d in TPUV3_32 if d.process_index == 0 - ] - process_index_fn.return_value = 0 - - global_mesh = partitioning.default_mesh(4) - self.assertEqual(global_mesh.axis_names, ('data', 'model')) - self.assertEqual( - global_mesh.shape, collections.OrderedDict((('data', 8), ('model', 4))) - ) - self.assertEqual(global_mesh.size, 32) - - for process_index in (0, 1, 2, 3): - process_index_fn.return_value = process_index - local_mesh = global_mesh.local_mesh - self.assertEqual(local_mesh.axis_names, ('data', 'model')) - self.assertEqual( - local_mesh.shape, collections.OrderedDict((('data', 2), ('model', 4))) - ) - self.assertEqual(local_mesh.size, 8) - - process_index_fn.return_value = 0 - local_mesh = global_mesh.local_mesh - lds = np.array( - [ - [ - TpuDevice( - id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0 - ), - TpuDevice( - id=1, process_index=0, coords=(0, 0, 0), core_on_chip=1 - ), - TpuDevice( - id=2, process_index=0, coords=(1, 0, 0), core_on_chip=0 - ), - TpuDevice( - id=3, process_index=0, coords=(1, 0, 0), core_on_chip=1 - ), - ], - [ - TpuDevice( - id=8, process_index=0, coords=(0, 1, 0), core_on_chip=0 - ), - TpuDevice( - id=9, process_index=0, coords=(0, 1, 0), core_on_chip=1 - ), - TpuDevice( - id=10, process_index=0, coords=(1, 1, 0), core_on_chip=0 - ), - TpuDevice( - id=11, process_index=0, coords=(1, 1, 0), core_on_chip=1 - ), - ], - ], - dtype=object, - ) - np.testing.assert_array_equal(local_mesh.devices, lds) - - @unittest.skipIf(jax.__version_info__ < (0, 4, 5), 'Test requires jax 0.4.5') - @mock.patch('jax.local_devices') - @mock.patch('jax.devices') - @mock.patch(f'{jax.process_index.__module__}.process_index') - def test_local_chunker(self, process_index_fn, devices_fn, local_devices_fn): - devices_fn.return_value = TPUV3_32 - local_devices_fn.return_value = [ - d for d in TPUV3_32 if d.process_index == 0 - ] - process_index_fn.return_value = 0 - global_mesh = partitioning.default_mesh(4) - local_chunker = partitioning.LocalChunker(global_mesh) - self.assertEqual(local_chunker.num_chunks['data'], 4) - self.assertEqual(local_chunker.num_chunks['model'], 1) - - # Derive the chunk order along the first 'data' dim for testing. - host_ordering = [] - for d in global_mesh.devices[:, 0]: - if d.process_index not in host_ordering: - host_ordering.append(d.process_index) - process_index_to_data_pos = { - process_index: idx for idx, process_index in enumerate(host_ordering) - } - - for process_indexx in (0, 1, 2, 3): - process_index_fn.return_value = process_indexx - global_mesh = partitioning.default_mesh(4) - local_chunker = partitioning.LocalChunker(global_mesh) - # get expected chunk for 'data' axis. - expected_chunk = process_index_to_data_pos[process_indexx] - self.assertEqual(local_chunker.chunk_ids['data'], expected_chunk) - self.assertEqual(local_chunker.chunk_ids['model'], 0) - # Sharded along both axes. - local_chunk_info = local_chunker.get_local_chunk_info( - (128, 16), ['data', 'model'] - ) - self.assertEqual(local_chunk_info.replica_id, 0) - self.assertEqual( - local_chunk_info.slice, - (slice(32 * expected_chunk, 32 * (expected_chunk + 1)), slice(0, 16)), - ) - # Replicated across first axis. - local_chunk_info = local_chunker.get_local_chunk_info( - (128, 16), [None, 'model'] - ) - self.assertEqual(local_chunk_info.replica_id, expected_chunk) - self.assertEqual(local_chunk_info.slice, (slice(None), slice(0, 16))) - - -class ModelBasedPartitionerTest(parameterized.TestCase): - - def get_axes_spec(self, partitioner, factored, momentum): - opt_def = adafactor.Adafactor( - learning_rate=0.1, - factored=factored, - min_dim_size_to_factor=8, - beta1=0.1 if momentum else None, - logical_factor_rules={ - 'batch': adafactor.FactorDim.NONE, - 'embed': adafactor.FactorDim.ROW, - 'vocab': adafactor.FactorDim.COLUMN, - 'mlp': adafactor.FactorDim.COLUMN, - }, - ) - state = train_state.FlaxOptimTrainState.create( - opt_def, - flax.core.freeze({ - 'params': { - 'logits_dense': np.ones((16, 16), np.float32), - 'mlp': {'wo': {'kernel': np.ones((32, 16), np.float32)}}, - }, - 'params_axes': { - 'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')), - 'mlp': { - 'wo': {'kernel_axes': AxisMetadata(names=('embed', 'mlp'))} - }, - }, - }), - ) - return partitioner.get_mesh_axes(state).state_dict() - - def get_expected_axes_spec( - self, spec_0, spec_1, kernel_spec=PartitionSpec(None, 'model') - ): - return train_state.FlaxOptimTrainState( - optimizers.Optimizer( - # opt_def, - adafactor.Adafactor(0.1), # opt_def not compared. - state=optimizers.OptimizerState( - step=None, - param_states={ - 'logits_dense': spec_0, - 'mlp': {'wo': {'kernel': spec_1}}, - }, - ), - target={ - 'logits_dense': PartitionSpec('model', None), - 'mlp': {'wo': {'kernel': kernel_spec}}, - }, - ) - ).state_dict() - - def test_get_mesh_axes(self): - partitioner = partitioning.PjitPartitioner( - num_partitions=1, - logical_axis_rules=( - ('batch', 'data'), - ('embed', None), - ('vocab', 'model'), - ('mlp', 'model'), - ), - ) - - p0_spec = PartitionSpec('model', None) - p1_spec = PartitionSpec(None, 'model') - - # Test quadrant of conditions: factored or not / momentum or not. - axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=False) - expected_axes_spec = self.get_expected_axes_spec( - adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None), - adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None), - ) - jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) - - axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=True) - expected_axes_spec = self.get_expected_axes_spec( - adafactor._AdafactorParamState( - m=p0_spec, v=None, v_col=None, v_row=None - ), - adafactor._AdafactorParamState( - m=p1_spec, v=None, v_col=None, v_row=None - ), - ) - jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) - - axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=True) - expected_axes_spec = self.get_expected_axes_spec( - adafactor._AdafactorParamState( - m=p0_spec, v=p0_spec, v_col=None, v_row=None - ), - adafactor._AdafactorParamState( - m=p1_spec, v=p1_spec, v_col=None, v_row=None - ), - ) - jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) - - axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=False) - expected_axes_spec = self.get_expected_axes_spec( - adafactor._AdafactorParamState( - m=None, v=p0_spec, v_col=None, v_row=None - ), - adafactor._AdafactorParamState( - m=None, v=p1_spec, v_col=None, v_row=None - ), - ) - jax.tree.map(self.assertEqual, axes_spec, expected_axes_spec) - - @parameterized.product(activation_dims=(1, 2), param_dims=(1, 2)) - def test_standard_logical_axis_rules(self, activation_dims, param_dims): - default_rules = partitioning.standard_logical_axis_rules( - activation_dims, param_dims, additional_rules=None - ) - custom_rules = ( - ('my-new-axis', 'data'), - ('another-axis', None), - ('another-one', 'model'), - ) - new_rules = partitioning.standard_logical_axis_rules( - activation_dims, param_dims, additional_rules=custom_rules - ) - self.assertEqual(new_rules[: len(default_rules)], default_rules) - self.assertEqual(new_rules[len(default_rules) :], list(custom_rules)) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/precompile.py b/t5x-main/t5x/precompile.py deleted file mode 100644 index eb7664b33b596d6df6135977af73be1e6d260eb0..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/precompile.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Precompile and generates HLO from TPU metadata backend. - -TPU Metadata backend is a TPU backend without real TPU devices while supporting -any TPU topologies, to allow work that doesn't require real TPUs to run as if -it is, e.g., compiling/lowering a HLO graph with the backend. - -Ideally, the precompile defaults to cpu backend for default device array -placement since metadata backend does not have memory allocation. - -The pjit function is pinned to use available TPU Metadata backend, for getting -a proper lowering under TPU mesh. - -""" - -import os -from typing import Callable, Optional - -import clu.data - -import jax -from jax import random -import numpy as np -import t5.data.mixtures # pylint:disable=unused-import -from t5x import models -from t5x import partitioning -from t5x import trainer as trainer_lib -from t5x import utils -import tensorflow as tf - - - -def precompile( - *, - model: models.BaseTransformerModel, - train_dataset_cfg: utils.DatasetConfig, - partitioner: partitioning.BasePartitioner, - model_dir: str, - random_seed: Optional[int], - get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset, - verify_matching_vocabs_fn: Optional[ - Callable[[utils.DatasetConfig, models.BaseTransformerModel], None] - ] = utils.verify_matching_vocabs, -): - """Compiles and dump the HLO to model dir, with HLO text dumps.""" - rng = random.PRNGKey(random_seed or 42) - _, trainer_rng = random.split(rng, 2) - - # TODO(hthu): Find a better way of getting dataset shapes instead of actually - # reading database and iterate on it. - data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size) - ds_shard_id = data_layout.shard_id - num_ds_shards = data_layout.num_shards - - if verify_matching_vocabs_fn is not None: - verify_matching_vocabs_fn(train_dataset_cfg, model) - - train_iter = get_dataset_fn( - train_dataset_cfg, ds_shard_id, num_ds_shards, model.FEATURE_CONVERTER_CLS - ) - if isinstance(train_iter, tf.data.Dataset): - train_iter = clu.data.TfDatasetIterator(train_iter, checkpoint=True) - elif not isinstance(train_iter, clu.data.dataset_iterator.DatasetIterator): - raise ValueError( - f'get_dataset_fn returned unsupported type {type(train_iter)}.' - ) - - # Need to use full batch size. - input_shapes = jax.tree.map( - lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec - ) - input_types = jax.tree.map(lambda x: x.dtype, train_iter.element_spec) - dummy_batch = jax.tree.map( - lambda x: np.ones(x.shape, x.dtype), train_iter.element_spec - ) - - # Compiling does not care about loading real weights. - train_state_initializer = utils.TrainStateInitializer( - optimizer_def=model.optimizer_def, - init_fn=model.get_initial_variables, - input_shapes=input_shapes, - input_types=input_types, - partitioner=partitioner, - ) - train_state_shape = train_state_initializer.global_train_state_shape - train_state_axes = train_state_initializer.train_state_axes - - def train_step(train_state, batch): - return trainer_lib.train_with_lr( # pytype: disable=wrong-arg-types # jax-ndarray - train_state, - batch, - learning_rate=1e-3, - dropout_rng=trainer_rng, - model=model, - num_microbatches=None, - weight_metrics_computer=None, - ) - - partitioned_step = partitioner.partition( - train_step, - in_axis_resources=( - train_state_axes, - partitioning.PartitionSpec( - 'data', - ), - ), - out_axis_resources=(train_state_axes, None), - donate_argnums=(0,), - ) - - # PartitionedTrainCallable has lower() defined but isn't exposed in pytype. - # TODO(hthu): Explicitly expose the lower() interface. - # pytype: disable=attribute-error - lowered = partitioned_step.lower(train_state_shape, dummy_batch) - # pytype: enable=attribute-error - - - # TODO(hthu): Make this a proper library without writing files by default. - tf.io.gfile.makedirs(model_dir) - with tf.io.gfile.GFile( - os.path.join(model_dir, 'lowered_hlo_pre_optimization'), 'w' - ) as f: - f.write(lowered.compiler_ir(dialect='hlo').as_serialized_hlo_module_proto()) - compiled = lowered.compile() - output_path = os.path.join(model_dir, 'lowered_hlo_post_optimization') - with tf.io.gfile.GFile(output_path, 'w') as f: - f.write( - compiled.runtime_executable() - .hlo_modules()[0] - .as_serialized_hlo_module_proto() - ) - with tf.io.gfile.GFile(os.path.join(model_dir, 'assignment'), 'wb') as f: - np.save(f, partitioner.mesh.device_ids) diff --git a/t5x-main/t5x/scripts/__init__.py b/t5x-main/t5x/scripts/__init__.py deleted file mode 100644 index 548e50465de0fcf5c81a4b08186d8164f705908d..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/scripts/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This empty file is needed to be recognized as a package by the setuptools.""" diff --git a/t5x-main/t5x/scripts/convert_tf_checkpoint.py b/t5x-main/t5x/scripts/convert_tf_checkpoint.py deleted file mode 100644 index e5a0caacef7227651d57dc20b57aab431a680f9b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/scripts/convert_tf_checkpoint.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Tool to convert a T5/MeshTF checkpoint to T5X. - -While T5X can be load these checkpoints on-the-fly, the process can be slow -for very large checkpoints. For frequently used checkpoints, it's recommended to -convert them once to a T5X checkpoint. - -Example usage: - -CUDA_VISIBLE_DEVICES="" -python -m t5x.convert_tf_checkpoint \ - --gin_file=t5x/examples/t5/t5_1_0/small.gin\ - --gin.convert_checkpoint.model=%MODEL\ - --gin.convert_checkpoint.tf_checkpoint_path=\ -\"gs://t5-data/pretrained_models/small/model.ckpt-1000000\"\ - --gin.convert_checkpoint.output_dir=\"/tmp/t5x_checkpoints/t5_small\"\ - --logtostderr -""" - -import jax -import jax.numpy as jnp -from t5x import checkpoints -from t5x import models -from t5x import partitioning -from t5x import train_state as train_state_lib - - -def convert_checkpoint( - model: models.BaseModel, - tf_checkpoint_path: str, - output_dir: str, - save_dtype: jnp.dtype = jnp.float32, - concurrent_gb: int = 16, -): - """Converts a TensorFlow checkpoint to a P5X checkpoint. - - Args: - model: - tf_checkpoint_path: Path to a TensorFlow checkpoint to convert. - output_dir: Path to a directory to write the converted checkpoint. - save_dtype: What dtype to store the target parameters as. - concurrent_gb: Number of gigabtes of parameters to convert in parallel. - Actual RAM usage may be 4X this number. - """ - - def initialize_train_state(rng): - initial_variables = model.get_initial_variables( # pytype: disable=wrong-arg-types # jax-array - rng=rng, - input_shapes={ - 'encoder_input_tokens': (1, 1), - 'decoder_input_tokens': (1, 1), - }, - ) - return train_state_lib.FlaxOptimTrainState.create( - model.optimizer_def, initial_variables - ) - - train_state = jax.eval_shape(initialize_train_state, jax.random.PRNGKey(0)) - - partitioner = partitioning.PjitPartitioner(1) - - checkpointer = checkpoints.Checkpointer( - train_state, partitioner, output_dir, save_dtype=jnp.dtype(save_dtype) - ) - - checkpointer.convert_from_tf_checkpoint( - tf_checkpoint_path, concurrent_gb=concurrent_gb - ) - - -if __name__ == '__main__': - # pylint:disable=g-import-not-at-top - from absl import flags - import gin - from t5x import gin_utils - # pylint:disable=g-import-not-at-top - - FLAGS = flags.FLAGS - - jax.config.parse_flags_with_absl() - - flags.DEFINE_multi_string( - 'gin_file', - default=None, - help=( - 'Path to gin configuration file. Multiple paths may be passed and ' - 'will be imported in the given order, with later configurations ' - 'overriding earlier ones.' - ), - ) - - flags.DEFINE_multi_string( - 'gin_bindings', default=[], help='Individual gin bindings' - ) - - flags.DEFINE_list( - 'gin_search_paths', - default=['t5x/configs'], - help=( - 'Comma-separated list of gin config path prefixes to be prepended ' - 'to suffixes given via `--gin_file`. If a file appears in. Only the ' - 'first prefix that produces a valid path for each suffix will be ' - 'used.' - ), - ) - - def main(_): - """True main function.""" - convert_checkpoint_using_gin = gin.configurable(convert_checkpoint) - - gin_utils.parse_gin_flags( - FLAGS.gin_search_paths, FLAGS.gin_file, FLAGS.gin_bindings - ) - # Get gin-configured version of `convert_checkpoint`. - convert_checkpoint_using_gin() - - gin_utils.run(main) diff --git a/t5x-main/t5x/scripts/xm_launch.py b/t5x-main/t5x/scripts/xm_launch.py deleted file mode 100644 index 8e21af4aa24f2976c4a4f5a0de3eb739b7bbe4fb..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/scripts/xm_launch.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""XManager launcher for t5x. - -Read about XManager: -https://github.com/deepmind/xmanager - -Usage: -xmanager xm_launch.py -- \ - --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \ - --model_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d) \ - --tfds_data_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data -""" - -import collections -import os -import shutil -import sys -import tempfile -from typing import Any, Dict - -from absl import app -from absl import flags -from xmanager import xm -from xmanager import xm_local -from xmanager.contrib import copybara - -_NAME = flags.DEFINE_string( - 'name', - 't5x', - 'Name of the experiment.', -) -_RUN_MODE = flags.DEFINE_enum( - 'run_mode', - 'train', - ['train', 'eval', 'infer'], - 'The mode to run T5X under', -) -_CLONE_GITHUB = flags.DEFINE_bool( - 'clone_github', - False, - 'If True, clone t5x/ from GitHub. Otherwise, use the local version.', -) -_COPYBARA_CONFIG = flags.DEFINE_string( - 'copybara_config', - None, - 'Copybara config to use. See https://github.com/google/copybara ' - 'If None, the local t5x directory will be copied with no modifications.', -) -_COPYBARA_WORKFLOW = flags.DEFINE_string( - 'copybara_workflow', - 'local', - 'Copybara workflow to apply with --copybara_config', -) -_COPYBARA_ORIGIN = flags.DEFINE_string( - 'copybara_origin', - '..', - 'Copybara origin folder to apply with --copybara_config', -) - -_TPU_CORES = flags.DEFINE_integer( - 'tpu_cores', - 8, - 'Number of TPU cores to run. There will be a new worker every 8 cores. ' - 'TPU types: https://cloud.google.com/tpu/docs/types-zones#types', -) -_MODEL_DIR = flags.DEFINE_string( - 'model_dir', - None, - 'Model dir to save logs, ckpts, etc. in "gs://model_dir" format.', -) -_TFDS_DATA_DIR = flags.DEFINE_string( - 'tfds_data_dir', - None, - 'Data dir to save the processed dataset in "gs://data_dir" format.', -) -_SEQIO_CACHE_DIRS = flags.DEFINE_list( - 'seqio_additional_cache_dirs', - [], - 'Comma separated directories in "gs://cache_dir" format to search for' - ' cached Tasks in addition to defaults.', -) -_PROJECT_DIRS = flags.DEFINE_list( - 'project_dirs', - None, - 'Project dir with custom components.', -) -_PIP_INSTALL = flags.DEFINE_list( - 'pip_install', - None, - 'Extra pip packages to install.', -) - - -@xm.run_in_asyncio_loop -async def main(_, gin_args: Dict[str, Any]): - name = 't5x' - async with xm_local.create_experiment(experiment_title=name) as experiment: - # TODO(chenandrew) Vertex Tensorboard is not supported for TPUs. - # https://github.com/deepmind/xmanager/issues/11 - # vertex = xm_local.vertex_client() - # tensorboard_name = await vertex.get_or_create_tensorboard(name) - # tensorboard = xm_local.TensorboardCapability( - # name=tensorboard_name, - # base_output_directory=_MODEL_DIR.value) - tensorboard = None - executor = xm_local.Vertex( - requirements=xm.JobRequirements(tpu_v2=_TPU_CORES.value), - tensorboard=tensorboard, - ) - - staging = os.path.join(tempfile.mkdtemp(), _NAME.value) - os.makedirs(staging) - # The t5x/ root directory. - t5x_path = os.path.abspath(os.path.join(__file__, '..', '..', '..')) - t5x_destination = os.path.join(staging, 't5x') - if _COPYBARA_CONFIG.value: - t5x_path = copybara.run_workflow( - _COPYBARA_CONFIG.value, - _COPYBARA_WORKFLOW.value, - _COPYBARA_ORIGIN.value, - t5x_destination, - ) - - if _CLONE_GITHUB.value: - copy_t5x = [ - 'RUN git clone --branch=main https://github.com/google-research/t5x', - ] - else: - if t5x_path != t5x_destination: - shutil.copytree(t5x_path, t5x_destination) - staging_t5x_path = os.path.join(os.path.basename(staging), 't5x') - copy_t5x = [f'COPY {staging_t5x_path}/ t5x'] - - copy_projects = [] - if _PROJECT_DIRS.value: - for project_dir in _PROJECT_DIRS.value: - project_name = os.path.basename(project_dir) - shutil.copytree(project_dir, os.path.join(staging, project_name)) - staging_project_dir = os.path.join( - os.path.basename(staging), project_name - ) - copy_projects.append(f'COPY {staging_project_dir}/ {project_name}') - - pip_install = [] - if _PIP_INSTALL.value: - pip_install = [ - 'RUN python3 -m pip install ' + ' '.join(_PIP_INSTALL.value) - ] - - [executable] = experiment.package([ - xm.python_container( - executor.Spec(), - path=staging, - # TODO(chenandrew): deeplearning image is still on python3.7 - # base_image='gcr.io/deeplearning-platform-release/base-cpu', - base_image='python:3.10', - docker_instructions=[ - *copy_t5x, - 'WORKDIR t5x', - # Install gcloud. This is normally part of deeplearning image. - # Since we use python:3.10, we need to do this manually. - 'RUN apt-get install apt-transport-https ca-certificates gnupg', - ( - 'RUN echo "deb' - ' [signed-by=/usr/share/keyrings/cloud.google.gpg]' - ' http://packages.cloud.google.com/apt cloud-sdk main" |' - ' tee -a /etc/apt/sources.list.d/google-cloud-sdk.list &&' - ' curl' - ' https://packages.cloud.google.com/apt/doc/apt-key.gpg |' - ' apt-key --keyring /usr/share/keyrings/cloud.google.gpg ' - ' add - && apt-get update -y && apt-get install' - ' google-cloud-cli -y' - ), - ( - 'RUN python3 -m pip install -e ".[tpu]" -f' - ' https://storage.googleapis.com/jax-releases/libtpu_releases.html' - ), - *pip_install, - *copy_projects, - ], - entrypoint=xm.CommandList([ - f'export MODEL_DIR=\'"{_MODEL_DIR.value}/logs"\'', - f'export TFDS_DATA_DIR={_TFDS_DATA_DIR.value}', - 'export SEQIO_CACHE_DIRS={}'.format( - ','.join(_SEQIO_CACHE_DIRS.value) - ), - 'export T5X_DIR=.', - ( - 'python3 ${T5X_DIR}/t5x/main.py ' - f'--run_mode={_RUN_MODE.value} ' - '--gin.MODEL_DIR=${MODEL_DIR} ' - '--tfds_data_dir=${TFDS_DATA_DIR} ' - '--undefok=seqio_additional_cache_dirs ' - '--seqio_additional_cache_dirs=${SEQIO_CACHE_DIRS} ' - ), - ]), - ), - ]) - args = [] - for k, l in gin_args.items(): - for v in l: - if "'" or '"' in v: - args.append(xm.ShellSafeArg(f'--{k}={v}')) - else: - args.append(f'--{k}={v}') - - experiment.add(xm.Job(executable=executable, executor=executor, args=args)) - - -def _split_gin_args(argv, prefix='--gin'): - """Separates absl and gin args into separate lists.""" - other_args = [argv[0]] - gin_args = collections.defaultdict(list) - for arg in argv[1:]: - if arg.startswith(prefix): - k, v = arg[len('--') :].split('=', maxsplit=1) - gin_args[k].append(v) - else: - other_args.append(arg) - return other_args, gin_args - - -if __name__ == '__main__': - _other_args, _gin_args = _split_gin_args(sys.argv) - app.run(lambda argv: main(argv, _gin_args), _other_args) diff --git a/t5x-main/t5x/state_utils.py b/t5x-main/t5x/state_utils.py deleted file mode 100644 index aa1af463777b8980607714ea18d0e2321238e007..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/state_utils.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for processing optimizer states.""" - -import re -from typing import Any, Mapping, Optional, Sequence, Tuple - -from absl import logging -from flax import traverse_util - - -def tensorstore_leaf(_, value): - """Detect if the node is a serialized tensorstore spec. - - Args: - _: The unused name of the current item. - value: The value of the possible leaf. - - Returns: - True if the value represents a tensorstore spec, False otherwise. - """ - # It is a tensorstore leaf if it at least has `driver`, `kvstore` and - # `metadata` in its keys, sometime they have additional ones like `dtype` or - # `transform`. - return set(value.keys()) >= {"driver", "kvstore", "metadata"} - - -def flatten_state_dict(state_dict, keep_empty_nodes: bool = False): - """Flatten a dictionary until an array or tensorstore is reached. - - Args: - state_dict: Optimizer state as nested dictionary. - keep_empty_nodes: Whether to keep empty node, for example, empty param - states from simple optimizers or non-touched parameter states in a - multioptimizer. - - Returns: - Flattened dictionary, though keeping tensor store state unflattened. - """ - return traverse_util.flatten_dict( - state_dict, - is_leaf=tensorstore_leaf, - keep_empty_nodes=keep_empty_nodes, - sep="/", - ) - - -def get_name_tree(state_dict, keep_empty_nodes: bool = False): - """Returns new state_dict with leaves as joined path keys separated by "/".""" - return traverse_util.unflatten_dict({ - k: "/".join(k) - for k in traverse_util.flatten_dict( - state_dict, keep_empty_nodes=keep_empty_nodes - ) - }) - - -def intersect_state( - state_dict: Mapping[str, Any], intersect_state_dict: Mapping[str, Any] -) -> Mapping[str, Any]: - """Drops non-matching entries from `state_dict`. - - Args: - state_dict: nested dict of optimizer state - intersect_state_dict: nested dict of entries to keep - - Returns: - nested dict like `state_dict` but with extra keys removed - """ - state_dict_flat = flatten_state_dict(state_dict) - intersect_state_dict_flat = flatten_state_dict(intersect_state_dict) - - for k in list(state_dict_flat): - if k not in intersect_state_dict_flat: - state_dict_flat.pop(k) - logging.warning( - "Not restoring param=%s because it's missing in the checkpoint", k - ) - - state_dict = traverse_util.unflatten_dict(state_dict_flat, sep="/") - - return state_dict - - -def merge_state( - state_dict: Mapping[str, Any], - from_scratch_state: Mapping[str, Any], - overwrite: bool = False, -) -> Mapping[str, Any]: - """Inserts new entries into `state_dict`. - - Args: - state_dict: nested dict of optimizer state - from_scratch_state: nested dict of entries to insert - overwrite: if True, values present in both state_dict and from_scratch_state - will be present in the result with the value taken from - `from_scratch_state`. - - Returns: - a nested dict like `state_dict` but with extra entries from - `from_scratch_state` inserted - """ - state_dict_flat = flatten_state_dict(state_dict) - from_scratch_state_flat = flatten_state_dict(from_scratch_state) - - for k in from_scratch_state_flat: - if k not in state_dict_flat or overwrite: - logging.warning("Initializing param=%s from scratch", k) - state_dict_flat[k] = from_scratch_state_flat[k] - - state_dict = traverse_util.unflatten_dict(state_dict_flat, sep="/") - - return state_dict - - -def apply_assignment_map( - ckpt_optimizer_state, - optimizer_state, - assignment_map: Sequence[Tuple[str, Optional[str]]], - require_all_rules_match: bool = True, - *, - is_resuming: bool = False, -): - """Applies an assignment map to a checkpoint optimizer state. - - In contrast to previous implementations, this has a switch whether to require - that all rules match, and has somewhat-custom-but-sensible replacement rules: - - 1. old keys that are matched are removed. - 2. old keys that don't match are retained. - 3. if two new keys map to the same old key, they both get assigned to its - value. - 4. if a new key isn't mapped but is in the checkpoint, it is copied over. - 5. new keys with None-valued replacement patterns are removed. - - Args: - ckpt_optimizer_state: Optimizer state in the checkpoint (usually, previous - model). - optimizer_state: optimizer state in the current model. - assignment_map: List of tuples (matcher, replacement) where matcher is a - regex, and replacement is a string replacement (possibly with - regex-compatible group match codes) or None if the matching state should - be dropped. - require_all_rules_match: Whether to require that all rules match. - is_resuming: Whether we are resuming a training run (True) or initializing a - new one (False). - - Returns: - New, remapped optimizer state. - """ - if is_resuming: - # Do not apply the transformation when resuming after a temporary stop. - # This ensures that the transformation will only happen once. - return ckpt_optimizer_state - - flat_ckpt = flatten_state_dict(ckpt_optimizer_state) - unmapped_old_keys = flat_ckpt.copy() - result = {} - explicitly_skipped_keys = set() - flat_opt = flatten_state_dict(optimizer_state) - - used_patterns = set() - for k in flat_opt: - for pattern, repl in assignment_map: - p_match = re.fullmatch(pattern, k) - if p_match: - # Skip initialization if the replacement pattern for this key is None. - if repl is None: - explicitly_skipped_keys.add(k) - used_patterns.add(pattern) - logging.info( - "Skipping optimizer param=%s, which had a None " - "replacement using pattern=%s in the assignment map.", - k, - pattern, - ) - break - - old_k = p_match.expand(repl) - used_patterns.add(pattern) - - # Remove the old key, but read the value from the original dict since - # it's OK if it was referenced twice. - unmapped_old_keys.pop(old_k, None) - try: - result[k] = flat_ckpt[old_k] - logging.info( - "Assigning checkpoint param=%s to optimizer param=%s " - "using pattern=%s", - old_k, - k, - pattern, - ) - except KeyError: - raise ValueError( - f"Parameter '{old_k}' does not exist in restore checkpoint. " - f"Must be one of: {sorted(flat_ckpt.keys())}" - ) - break - - # Now re-add the unmapped keys. This is a 2-step process so that the `pop()` - # call above doesn't mis-fire if the assignment map "rotates" a chain of keys. - for key, v in unmapped_old_keys.items(): - if key not in explicitly_skipped_keys: - result[key] = v - - # If any new keys weren't mapped, but are in the old checkpoint, copy those. - for key in set(flat_opt) - set(result): - if key in explicitly_skipped_keys: - pass - elif key in flat_ckpt: - result[key] = flat_ckpt[key] - else: - logging.warning( - "Skipping key=%s which did not match assignment map or checkpoint.", - key, - ) - - if require_all_rules_match and len(assignment_map) != len(used_patterns): - unused_patterns = set(p for p, _ in assignment_map) - used_patterns - unused_patterns_str = ", ".join(f"'{p}'" for p in unused_patterns) - raise ValueError( - "Unused patterns in `assignment_map`: {" + unused_patterns_str + "}" - ) - - return traverse_util.unflatten_dict(result, sep="/") diff --git a/t5x-main/t5x/state_utils_test.py b/t5x-main/t5x/state_utils_test.py deleted file mode 100644 index 3d2b1d8fba21156921c6beb48009caf90aeb0b1c..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/state_utils_test.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for state_utils.""" - -import re - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -from t5x import state_utils - - -class StateUtilsTest(parameterized.TestCase): - - @parameterized.parameters( - dict( - state_dict={"a": {"b": 2, "c": 3}}, - intersect_state_dict={"a": {"b": 4}, "d": 5}, - expect_state={"a": {"b": 2}}, - ) - ) - def test_intersect_state( - self, state_dict, intersect_state_dict, expect_state - ): - actual_state = state_utils.intersect_state(state_dict, intersect_state_dict) - self.assertEqual(actual_state, expect_state) - - @parameterized.parameters( - dict( - state_dict={"a": {"b": 2, "c": 3}}, - merge_state_dict={"a": {"b": 4}, "d": 5}, - expect_state={"a": {"b": 2, "c": 3}, "d": 5}, - ) - ) - def test_merge_state(self, state_dict, merge_state_dict, expect_state): - actual_state = state_utils.merge_state(state_dict, merge_state_dict) - self.assertEqual(actual_state, expect_state) - - def test_tensorstore_leaf(self): - leaf = { - "driver": "zarr", - "kvstore": {"driver": "gfile", "path": "target.bias"}, - "metadata": { - "chunks": [4, 1], - "compressor": {"id": "gzip", "level": 1}, - "dtype": " int: - """Convert grid coordinates to linear index given a dimension ordering. - - Args: - coords: coordinates in minor to major ordering. - bounds: coordinate grid bonuds in SAME minor to major ordering as above. - - Returns: - Linear index for grid point. - """ - # Calculate stride multipliers. - strides = tuple(itertools.accumulate((1,) + bounds[:-1], operator.mul)) - # Sum linear index from strides and coords - return sum(jax.tree.map(lambda x, y: x * y, coords, strides)) - - -def make_devices( - nx: int, - ny: int, - nz: int, - nc: int = 2, - host_layout: Tuple[int, ...] = (2, 2, 1, 2), - kind='TPU v3', -): - """Create mock TPU devices.""" - devices = [] - device_bounds = (nx, ny, nz, nc) - hnx, hny, hnz, hnc = jax.tree.map( - lambda a, b: a // b, device_bounds, host_layout - ) - for x, y, z, c in itertools.product(*map(range, device_bounds)): - hx, hy, hz, hc = jax.tree.map( - lambda a, b: a // b, (x, y, z, c), host_layout - ) - # TODO(levskaya, jekbradbury): verify this id/host ordering on TPU v4 - device_id = coords_to_idx((c, x, y, z), (nc, nx, ny, nz)) # pytype: disable=wrong-arg-types - process_index = coords_to_idx((hc, hx, hy, hz), (hnc, hnx, hny, hnz)) # pytype: disable=wrong-arg-types - devices.append( - TpuDevice( - id=device_id, - process_index=process_index, - coords=(x, y, z), - core_on_chip=c, - platform='tpu', - device_kind=kind, - ) - ) - return devices - - -def make_train_state_base( - *, - step: int, - params: Mapping[str, Any], - param_states: Mapping[str, Any], - flax_optimizer_def: optimizers.OptimizerDefType = optimizers.sgd(0.1), -) -> train_state_lib.TrainState: - """Helper to construct a train state for testing.""" - optimizer = optimizers.Optimizer( - flax_optimizer_def, - state=optimizers.OptimizerState( # pytype: disable=wrong-arg-types # jax-ndarray - step=step, param_states=param_states - ), - target=params, - ) - - return train_state_lib.FlaxOptimTrainState(optimizer) - - -def make_train_state_replicated( - global_input_shape, - step=42, - dtype=np.float32, -): - """Helper to construct a train state for testing.""" - bias = np.ones(global_input_shape, dtype=dtype) - kernel = np.arange(math.prod(global_input_shape), dtype=dtype).reshape( - global_input_shape - ) - train_state = make_train_state_base( - step=np.int32(step), - params={'bias': bias * 2, 'kernel': kernel * 2}, - param_states={ # only cast targets (above) - 'bias': bias.astype(np.float32), - 'kernel': kernel.astype(np.float32), - }, - ) - return train_state - - -def make_train_state( - global_mesh, global_input_shape, mesh_axes, step=42, dtype=np.float32 -): - """Construct a train state for testing.""" - train_state = make_train_state_replicated( - global_input_shape, step=step, dtype=dtype - ) - - return jax.tree.map( - functools.partial( - create_sharded_array, - global_shape=global_input_shape, - global_mesh=global_mesh, - mesh_axes=mesh_axes, - ), - train_state, - is_leaf=lambda x: isinstance(x, np.ndarray), - ) - - -def get_t5_test_model(**config_overrides) -> models.EncoderDecoderModel: - """Returns a tiny T5 1.1 model to use for testing.""" - tiny_config = network.T5Config( - vocab_size=32128, - dtype='bfloat16', - emb_dim=8, - num_heads=4, - num_encoder_layers=2, - num_decoder_layers=2, - head_dim=3, - mlp_dim=16, - mlp_activations=('gelu', 'linear'), - dropout_rate=0.0, - logits_via_embedding=False, - ) - - tiny_config = dataclasses.replace(tiny_config, **config_overrides) - sentencepiece_model_file = ( - 'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model' - ) - vocabulary = seqio.SentencePieceVocabulary(sentencepiece_model_file) - return models.EncoderDecoderModel( - module=network.Transformer(tiny_config), - input_vocabulary=vocabulary, - output_vocabulary=vocabulary, - optimizer_def=adafactor.Adafactor( - decay_rate=0.8, - step_offset=0, - logical_factor_rules=adafactor.standard_logical_factor_rules(), - ), - ) - - -# -------------------- Mesh parametrization helpers -------------------- -# Adapted from jax.test_util -MeshSpec = List[Tuple[str, int]] - - -@contextlib.contextmanager -def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: - """Test utility for setting up meshes given mesh data from `schedules`.""" - axis_names, shape = zip(*named_shape) - size = np.prod(shape) - local_devices = list(jax.local_devices()) - if len(local_devices) < size: - raise unittest.SkipTest(f'Test requires {size} local devices') - mesh_devices = np.array(local_devices[:size]).reshape(shape) - with Mesh(mesh_devices, axis_names): - yield - - -def create_global_mesh(mesh_shape, axis_names): - size = np.prod(mesh_shape) - if len(jax.devices()) < size: - raise unittest.SkipTest(f'Test requires {size} global devices.') - devices = sorted(jax.devices(), key=lambda d: d.id) - mesh_devices = np.array(devices[:size]).reshape(mesh_shape) - global_mesh = Mesh(mesh_devices, axis_names) - return global_mesh - - -def get_fake_vocab(): - """Creates fake vocabulary compatible with `get_fake_tokenized_dataset`.""" - - @dataclasses.dataclass - class DummyVocab: - vocab_size: int = 128 - eos_id: int = 1 - - vocab = DummyVocab() - return (vocab, vocab) - - -# Text preprocessed and tokenized. -_FAKE_TOKENIZED_DATASET = { - 'train': [ - { - 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), - 'inputs_pretokenized': 'complete: this', - 'targets': (3, 8, 6, 3, 5, 10), - 'targets_pretokenized': 'is a test', - }, - { - 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), - 'inputs_pretokenized': 'complete: that', - 'targets': (17, 5, 6, 3, 5, 10), - 'targets_pretokenized': 'was a test', - }, - { - 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), - 'inputs_pretokenized': 'complete: those', - 'targets': (17, 4, 23, 4, 10, 6), - 'targets_pretokenized': 'were tests', - }, - ], - # Notice that we repeat consecutively each examples 4 times, - # this needed for tests like infer_tests to validate determinism. - 'validation': [{ - 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), - 'inputs_pretokenized': 'complete: this', - 'targets': (3, 8, 6, 3, 5, 3, 25, 5), - 'targets_pretokenized': 'is a validation', - }] * 4 + [{ - 'inputs': (3, 13, 7, 14, 15, 9, 4, 17), - 'inputs_pretokenized': 'complete: that', - 'targets': (17, 5, 6, 3, 5, 22, 7, 24), - 'targets_pretokenized': 'was another validation', - }] * 4, -} - - -def get_fake_tokenized_dataset(*_, split='validation', **__): - """Creates fake dataset compatible with T5X models inputs.""" - - if split == 'test': - split = 'validation' - output_types = { - 'inputs': tf.int32, - 'targets': tf.int32, - 'inputs_pretokenized': tf.string, - 'targets_pretokenized': tf.string, - } - output_shapes = { - 'inputs': [None], - 'targets': [None], - 'inputs_pretokenized': [], - 'targets_pretokenized': [], - } - ds = tf.data.Dataset.from_generator( - lambda: _FAKE_TOKENIZED_DATASET[split], output_types, output_shapes - ) - if split == 'train': - ds = ds.repeat(None) - return ds - - -def assert_equal(a, b): - """Check equality of LazyArray / jax.Array / other array.""" - assert isinstance( - a, type(b) - ), f'Found incompatible types: {type(a)}, {type(b)}' - if isinstance(a, LazyArray): - a = a.get() - if isinstance(b, LazyArray): - b = b.get() - if not isinstance(a, jax.Array): - np.testing.assert_array_equal(a, b) - else: - for s1, s2 in zip(a.addressable_shards, b.addressable_shards): - np.testing.assert_array_equal(s1.data, s2.data) - - -def assert_same(tree_a, tree_b): - """Asserts that both trees are the same.""" - tree_a, tree_b = jax.device_get((tree_a, tree_b)) - jax.tree.map(assert_equal, tree_a, tree_b) - - -def get_train_state_from_variables( - variables, optimizer_def=adafactor.Adafactor(0.0) -): - """Returns a default Train State with Adafactor optimizer.""" - optimizer = optimizer_def.create(variables['params']) - return train_state_lib.FlaxOptimTrainState(optimizer) - - -def create_sharded_array(arr, global_shape, global_mesh, mesh_axes): - def cb(index): - return arr[index] - - if np.isscalar(arr): - return arr - return jax.make_array_from_callback( - global_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), cb - ) - - -class FakePartitioner(partitioning.BasePartitioner): - """Fake Partitioner for testing.""" - - def __init__(self, mesh, mesh_axes, params_on_devices=True): - super().__init__(num_partitions=1) - self._global_mesh = mesh - self._mesh_axes = mesh_axes - self._local_chunker = partitioning.LocalChunker(self.mesh) - self._params_on_devices = params_on_devices - - def get_data_layout(self, batch_size=None, host_index=None): - return partitioning.DataLayout( - batch_size=1, - shard_id=1, - num_shards=1, - is_first_host_in_replica_set=True, - ) - - @property - def mesh(self): - return self._global_mesh - - @property - def params_on_devices(self): - return self._params_on_devices - - def move_params_to_devices(self, train_state, train_state_axes): - return train_state - - def get_mesh_axes(self, train_state): - mesh_axes = jax.tree.map(lambda _: self._mesh_axes, train_state) - return mesh_axes.replace_step(None) - - def _local_chunker(self): - return self._local_chunker - - def partition( - self, - fn, - in_axis_resources, - out_axis_resources, - static_argnums=(), - donate_argnums=(), - ): - pjitted = pjit( - fn, - in_shardings=in_axis_resources, - out_shardings=out_axis_resources, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - ) - return partitioning.PjittedFnWithContext(pjitted, self.mesh) - - def compile(self, partitioned_fn, *args): - return None - -# -------------------- Checkpoint helpers -------------------- - - -def _train_state_shapes(train_state): - def _maybe_get(x): - if isinstance(x, LazyArray): - return x.get() - return x - - train_state = jax.tree_util.tree_map(_maybe_get, train_state) - return jax.eval_shape(lambda: train_state) - - -def save(checkpointer_or_manager, train_state, force=False): - saved = checkpointer_or_manager.save(train_state, force=force) - checkpointer_or_manager.wait_until_finished() - return saved - - -def create_checkpointer_or_manager( - train_state_shapes, - partitioner, - directory, - dataset_iterator=None, - save_dtype=None, - restore_dtype=None, - best=False, - keep=None, - period=1, - checkpoint_steps=None, - keep_checkpoints_without_metrics=True, -): - """Creates an Orbax CheckpointManagerInterface.""" - metric_name_to_monitor = 'train/accuracy' if best else None - return checkpoints.OrbaxCheckpointManagerInterface( - directory, - train_state_shapes, - partitioner, - dataset_iterator=dataset_iterator, - save_dtype=save_dtype, - restore_dtype=restore_dtype, - keep=keep, - period=period, - checkpoint_steps=checkpoint_steps, - metric_name_to_monitor=metric_name_to_monitor, - keep_checkpoints_without_metrics=keep_checkpoints_without_metrics, - ) diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/checkpoint b/t5x-main/t5x/testdata/mtf_tiny_t5/checkpoint deleted file mode 100644 index 92c11cc82b86d533cecb5bffaeffa6d8a0dcf484..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/mtf_tiny_t5/checkpoint +++ /dev/null @@ -1,2 +0,0 @@ -model_checkpoint_path: "model.ckpt-0" -all_model_checkpoint_paths: "model.ckpt-0" diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/graph.pbtxt b/t5x-main/t5x/testdata/mtf_tiny_t5/graph.pbtxt deleted file mode 100644 index 09e60d1105ff4a83bb179736a370e09f10dee197..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/mtf_tiny_t5/graph.pbtxt +++ /dev/null @@ -1,243629 +0,0 @@ -node { - name: "global_step/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@global_step" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "global_step" - op: "VarHandleOp" - attr { - key: "_class" - value { - list { - s: "loc:@global_step" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "allowed_devices" - value { - list { - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "shape" - value { - shape { - } - } - } - attr { - key: "shared_name" - value { - s: "global_step" - } - } -} -node { - name: "global_step/IsInitialized/VarIsInitializedOp" - op: "VarIsInitializedOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "global_step/Assign" - op: "AssignVariableOp" - input: "global_step" - input: "global_step/Initializer/zeros" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "global_step/Read/ReadVariableOp" - op: "ReadVariableOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "global_step/VarIsInitializedOp" - op: "VarIsInitializedOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "global_step/cond/Switch" - op: "Switch" - input: "global_step/VarIsInitializedOp" - input: "global_step/VarIsInitializedOp" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "global_step/cond/switch_t" - op: "Identity" - input: "global_step/cond/Switch:1" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "global_step/cond/switch_f" - op: "Identity" - input: "global_step/cond/Switch" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "global_step/cond/pred_id" - op: "Identity" - input: "global_step/VarIsInitializedOp" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "global_step/cond/Read/ReadVariableOp" - op: "ReadVariableOp" - input: "global_step/cond/Read/ReadVariableOp/Switch:1" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "global_step/cond/Read/ReadVariableOp/Switch" - op: "Switch" - input: "global_step" - input: "global_step/cond/pred_id" - attr { - key: "T" - value { - type: DT_RESOURCE - } - } - attr { - key: "_class" - value { - list { - s: "loc:@global_step" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "global_step/cond/Identity" - op: "Identity" - input: "global_step/cond/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "global_step/cond/Switch_1" - op: "Switch" - input: "global_step/Initializer/zeros" - input: "global_step/cond/pred_id" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@global_step" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "global_step/cond/Merge" - op: "Merge" - input: "global_step/cond/Switch_1" - input: "global_step/cond/Identity" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "global_step/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "global_step/add" - op: "AddV2" - input: "global_step/cond/Merge" - input: "global_step/add/y" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "normalize_element/component_0" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10000 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "n/a" - } - } - } -} -node { - name: "TensorSliceDataset" - op: "TensorSliceDataset" - input: "normalize_element/component_0" - device: "/device:CPU:0" - attr { - key: "Toutput_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "buffer_size" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 16 - } - } - } -} -node { - name: "seed" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "seed2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "ShuffleDataset" - op: "ShuffleDataset" - input: "TensorSliceDataset" - input: "buffer_size" - input: "seed" - input: "seed2" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "reshuffle_each_iteration" - value { - b: true - } - } -} -node { - name: "cycle_length" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 16 - } - } - } -} -node { - name: "block_length" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 16 - } - } - } -} -node { - name: "buffer_output_elements" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "prefetch_input_elements" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "num_parallel_calls" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelInterleaveDatasetV4" - op: "ParallelInterleaveDatasetV4" - input: "ShuffleDataset" - input: "cycle_length" - input: "block_length" - input: "buffer_output_elements" - input: "prefetch_input_elements" - input: "num_parallel_calls" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_interleave_read_file_fn_88" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } -} -node { - name: "num_shards" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1 - } - } - } -} -node { - name: "index" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "ShardDataset" - op: "ShardDataset" - input: "ParallelInterleaveDatasetV4" - input: "num_shards" - input: "index" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "require_non_empty" - value { - b: false - } - } -} -node { - name: "buffer_size_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "PrefetchDataset" - op: "PrefetchDataset" - input: "ShardDataset" - input: "buffer_size_1" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@ShardDataset" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "buffer_size_min" - value { - i: 0 - } - } - attr { - key: "legacy_autotune" - value { - b: true - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "slack_period" - value { - i: 0 - } - } -} -node { - name: "num_parallel_calls_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2" - op: "ParallelMapDatasetV2" - input: "PrefetchDataset" - input: "num_parallel_calls_1" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_lambda_123" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "num_parallel_calls_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2_1" - op: "ParallelMapDatasetV2" - input: "ParallelMapDatasetV2" - input: "num_parallel_calls_2" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_lambda_143" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "buffer_size_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1000 - } - } - } -} -node { - name: "seed_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "seed2_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "ShuffleDataset_1" - op: "ShuffleDataset" - input: "ParallelMapDatasetV2_1" - input: "buffer_size_2" - input: "seed_1" - input: "seed2_1" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "reshuffle_each_iteration" - value { - b: true - } - } -} -node { - name: "buffer_size_3" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "PrefetchDataset_1" - op: "PrefetchDataset" - input: "ShuffleDataset_1" - input: "buffer_size_3" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@ShuffleDataset_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "buffer_size_min" - value { - i: 0 - } - } - attr { - key: "legacy_autotune" - value { - b: true - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "slack_period" - value { - i: 0 - } - } -} -node { - name: "num_parallel_calls_3" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2_2" - op: "ParallelMapDatasetV2" - input: "PrefetchDataset_1" - input: "num_parallel_calls_3" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map__filter_features_159" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "num_parallel_calls_4" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2_3" - op: "ParallelMapDatasetV2" - input: "ParallelMapDatasetV2_2" - input: "num_parallel_calls_4" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_lambda_175" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "batch_size" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 512 - } - } - } -} -node { - name: "Const" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - dim { - size: 1 - } - } - int64_val: -1 - } - } - } -} -node { - name: "Const_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - dim { - size: 1 - } - } - int64_val: -1 - } - } - } -} -node { - name: "padding_value" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "padding_value_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "drop_remainder" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BOOL - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BOOL - tensor_shape { - } - bool_val: false - } - } - } -} -node { - name: "PaddedBatchDatasetV2" - op: "PaddedBatchDatasetV2" - input: "ParallelMapDatasetV2_3" - input: "batch_size" - input: "Const" - input: "Const_1" - input: "padding_value" - input: "padding_value_1" - input: "drop_remainder" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "Toutput_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "parallel_copy" - value { - b: false - } - } -} -node { - name: "num_parallel_calls_5" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2_4" - op: "ParallelMapDatasetV2" - input: "PaddedBatchDatasetV2" - input: "num_parallel_calls_5" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_custom_pack_batch_206" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "MapDataset" - op: "MapDataset" - input: "ParallelMapDatasetV2_4" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_normalize_222" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "UnbatchDataset" - op: "UnbatchDataset" - input: "MapDataset" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } -} -node { - name: "num_parallel_calls_6" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2_5" - op: "ParallelMapDatasetV2" - input: "UnbatchDataset" - input: "num_parallel_calls_6" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_my_fn_250" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "num_parallel_calls_7" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2_6" - op: "ParallelMapDatasetV2" - input: "ParallelMapDatasetV2_5" - input: "num_parallel_calls_7" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_lambda_296" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "num_parallel_calls_8" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "ParallelMapDatasetV2_7" - op: "ParallelMapDatasetV2" - input: "ParallelMapDatasetV2_6" - input: "num_parallel_calls_8" - device: "/device:CPU:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_lambda_340" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } -} -node { - name: "count" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "RepeatDataset" - op: "RepeatDataset" - input: "ParallelMapDatasetV2_7" - input: "count" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } -} -node { - name: "batch_size_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 32 - } - } - } -} -node { - name: "drop_remainder_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BOOL - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BOOL - tensor_shape { - } - bool_val: true - } - } - } -} -node { - name: "BatchDatasetV2" - op: "BatchDatasetV2" - input: "RepeatDataset" - input: "batch_size_1" - input: "drop_remainder_1" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "parallel_copy" - value { - b: false - } - } -} -node { - name: "buffer_size_4" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } -} -node { - name: "PrefetchDataset_2" - op: "PrefetchDataset" - input: "BatchDatasetV2" - input: "buffer_size_4" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@BatchDatasetV2" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "buffer_size_min" - value { - i: 0 - } - } - attr { - key: "legacy_autotune" - value { - b: true - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "slack_period" - value { - i: 0 - } - } -} -node { - name: "ModelDataset" - op: "ModelDataset" - input: "PrefetchDataset_2" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "algorithm" - value { - i: 0 - } - } - attr { - key: "cpu_budget" - value { - i: 0 - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "ram_budget" - value { - i: 0 - } - } -} -node { - name: "optimizations_enabled" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "optimizations_disabled" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "optimizations_default" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 3 - } - } - string_val: "map_and_batch_fusion" - string_val: "noop_elimination" - string_val: "shuffle_and_repeat_fusion" - } - } - } -} -node { - name: "OptimizeDatasetV2" - op: "OptimizeDatasetV2" - input: "ModelDataset" - input: "optimizations_enabled" - input: "optimizations_disabled" - input: "optimizations_default" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "optimization_configs" - value { - list { - s: "autotune_buffer_sizes:autotune:true" - s: "batch_parallelization:autotune:true" - s: "disable_prefetch_legacy_autotune:autotune:true" - s: "enable_gradient_descent:autotune:true" - s: "map_parallelization:autotune:true" - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } -} -node { - name: "IteratorV2" - op: "IteratorV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@BatchDatasetV2" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "MakeIterator" - op: "MakeIterator" - input: "OptimizeDatasetV2" - input: "IteratorV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@BatchDatasetV2" - } - } - } -} -node { - name: "IteratorToStringHandle" - op: "IteratorToStringHandle" - input: "IteratorV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@BatchDatasetV2" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "IteratorGetNext" - op: "IteratorGetNext" - input: "IteratorV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@BatchDatasetV2" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - shape { - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - type: DT_INT32 - } - } - } -} -node { - name: "Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "Reshape" - op: "Reshape" - input: "IteratorGetNext" - input: "Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "Reshape_1/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "Reshape_1" - op: "Reshape" - input: "IteratorGetNext:2" - input: "Reshape_1/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "Reshape_2/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "Reshape_2" - op: "Reshape" - input: "IteratorGetNext:1" - input: "Reshape_2/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "Reshape_3/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "Reshape_3" - op: "Reshape" - input: "IteratorGetNext:3" - input: "Reshape_3/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "Reshape_4/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "Reshape_4" - op: "Reshape" - input: "IteratorGetNext:5" - input: "Reshape_4/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "Reshape_5/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "Reshape_5" - op: "Reshape" - input: "IteratorGetNext:4" - input: "Reshape_5/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "Const_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "shared/embedding/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200}\000\000 \000\000\000" - } - } - } -} -node { - name: "shared/embedding/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "shared/embedding/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 16256 - } - } - } -} -node { - name: "shared/embedding/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "shared/embedding/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "shared/embedding/Initializer/random_normal/mul" - op: "Mul" - input: "shared/embedding/Initializer/random_normal/RandomStandardNormal" - input: "shared/embedding/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/Initializer/random_normal" - op: "Add" - input: "shared/embedding/Initializer/random_normal/mul" - input: "shared/embedding/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "shared/embedding/Assign" - op: "Assign" - input: "shared/embedding" - input: "shared/embedding/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "shared/embedding/read" - op: "Identity" - input: "shared/embedding" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale" - input: "encoder/block_000/layer_000/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/read" - op: "Identity" - input: "encoder/block_000/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15541 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mul" - input: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q" - input: "encoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mul" - input: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k" - input: "encoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mul" - input: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v" - input: "encoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15797 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mul" - input: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o" - input: "encoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/Const_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 8 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/minimum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 15 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\002\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 48855 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 16087 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/max" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale" - input: "encoder/block_000/layer_001/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/read" - op: "Identity" - input: "encoder/block_000/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15872 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale" - input: "encoder/block_001/layer_000/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/read" - op: "Identity" - input: "encoder/block_001/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15541 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mul" - input: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q" - input: "encoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mul" - input: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k" - input: "encoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mul" - input: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v" - input: "encoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15797 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - input: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mul" - input: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o" - input: "encoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/Const_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 8 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/minimum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 15 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale" - input: "encoder/block_001/layer_001/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/read" - op: "Identity" - input: "encoder/block_001/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15872 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "encoder/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/rms_norm/scale/Assign" - op: "Assign" - input: "encoder/rms_norm/scale" - input: "encoder/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/rms_norm/scale/read" - op: "Identity" - input: "encoder/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale" - input: "decoder/block_000/layer_000/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/read" - op: "Identity" - input: "decoder/block_000/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15541 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mul" - input: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q" - input: "decoder/block_000/layer_000/SelfAttention/q/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mul" - input: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k" - input: "decoder/block_000/layer_000/SelfAttention/k/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mul" - input: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v" - input: "decoder/block_000/layer_000/SelfAttention/v/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15797 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mul" - input: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o" - input: "decoder/block_000/layer_000/SelfAttention/o/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 16 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/minimum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 31 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\002\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 48855 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 16087 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/max" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Initializer/random_uniform" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale" - input: "decoder/block_000/layer_001/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/read" - op: "Identity" - input: "decoder/block_000/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15541 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q" - input: "decoder/block_000/layer_001/EncDecAttention/q/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/q" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k" - input: "decoder/block_000/layer_001/EncDecAttention/k/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/k" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v" - input: "decoder/block_000/layer_001/EncDecAttention/v/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15797 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o" - input: "decoder/block_000/layer_001/EncDecAttention/o/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/o" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale" - input: "decoder/block_000/layer_002/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/read" - op: "Identity" - input: "decoder/block_000/layer_002/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15872 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale" - input: "decoder/block_001/layer_000/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/read" - op: "Identity" - input: "decoder/block_001/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15541 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mul" - input: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q" - input: "decoder/block_001/layer_000/SelfAttention/q/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mul" - input: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k" - input: "decoder/block_001/layer_000/SelfAttention/k/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mul" - input: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v" - input: "decoder/block_001/layer_000/SelfAttention/v/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15797 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mul" - input: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o" - input: "decoder/block_001/layer_000/SelfAttention/o/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 16 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/minimum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 31 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale" - input: "decoder/block_001/layer_001/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/read" - op: "Identity" - input: "decoder/block_001/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15541 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q" - input: "decoder/block_001/layer_001/EncDecAttention/q/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/q" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k" - input: "decoder/block_001/layer_001/EncDecAttention/k/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/k" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v" - input: "decoder/block_001/layer_001/EncDecAttention/v/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15797 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/RandomStandardNormal" - op: "RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/RandomStandardNormal" - input: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o" - input: "decoder/block_001/layer_001/EncDecAttention/o/Initializer/random_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/o" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale" - input: "decoder/block_001/layer_002/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/read" - op: "Identity" - input: "decoder/block_001/layer_002/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15872 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/Initializer/ones" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - dim { - size: 32 - } - } - half_val: 16256 - } - } - } -} -node { - name: "decoder/rms_norm/scale" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/rms_norm/scale/Assign" - op: "Assign" - input: "decoder/rms_norm/scale" - input: "decoder/rms_norm/scale/Initializer/ones" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/rms_norm/scale/read" - op: "Identity" - input: "decoder/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/logits/kernel/Initializer/truncated_normal/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200}\000\000" - } - } - } -} -node { - name: "decoder/logits/kernel/Initializer/truncated_normal/mean" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 0 - } - } - } -} -node { - name: "decoder/logits/kernel/Initializer/truncated_normal/stddev" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BFLOAT16 - tensor_shape { - } - half_val: 15925 - } - } - } -} -node { - name: "decoder/logits/kernel/Initializer/truncated_normal/TruncatedNormal" - op: "TruncatedNormal" - input: "decoder/logits/kernel/Initializer/truncated_normal/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/logits/kernel/Initializer/truncated_normal/mul" - op: "Mul" - input: "decoder/logits/kernel/Initializer/truncated_normal/TruncatedNormal" - input: "decoder/logits/kernel/Initializer/truncated_normal/stddev" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/Initializer/truncated_normal" - op: "Add" - input: "decoder/logits/kernel/Initializer/truncated_normal/mul" - input: "decoder/logits/kernel/Initializer/truncated_normal/mean" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/logits/kernel/Assign" - op: "Assign" - input: "decoder/logits/kernel" - input: "decoder/logits/kernel/Initializer/truncated_normal" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/logits/kernel/read" - op: "Identity" - input: "decoder/logits/kernel" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/Const_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/sub/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/sub/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/sub/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/sub/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "ReadVariableOp" - op: "ReadVariableOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "sub/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "sub" - op: "Sub" - input: "ReadVariableOp" - input: "sub/y" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Cast" - op: "Cast" - input: "sub" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Maximum/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 10000.0 - } - } - } -} -node { - name: "Maximum" - op: "Maximum" - input: "Cast" - input: "Maximum/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Rsqrt" - op: "Rsqrt" - input: "Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "mul_2/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "mul_2" - op: "Mul" - input: "mul_2/x" - input: "Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "ReadVariableOp_1" - op: "ReadVariableOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "sub_1/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "sub_1" - op: "Sub" - input: "ReadVariableOp_1" - input: "sub_1/y" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Cast_1" - op: "Cast" - input: "sub_1" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "sub_2/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 100.0 - } - } - } -} -node { - name: "sub_2" - op: "Sub" - input: "sub_2/x" - input: "Cast_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "truediv/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 10.0 - } - } - } -} -node { - name: "truediv" - op: "RealDiv" - input: "sub_2" - input: "truediv/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Minimum/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "Minimum" - op: "Minimum" - input: "Minimum/x" - input: "truediv" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "mul_3" - op: "Mul" - input: "mul_2" - input: "Minimum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "ReadVariableOp_2" - op: "ReadVariableOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "sub_3/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "sub_3" - op: "Sub" - input: "ReadVariableOp_2" - input: "sub_3/y" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "mul_4/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "mul_4" - op: "Mul" - input: "mul_3" - input: "mul_4/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "learning_rate/tags" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "learning_rate" - } - } - } -} -node { - name: "learning_rate" - op: "ScalarSummary" - input: "learning_rate/tags" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "ReadVariableOp_3" - op: "ReadVariableOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "sub_4/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } -} -node { - name: "sub_4" - op: "Sub" - input: "ReadVariableOp_3" - input: "sub_4/y" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Cast_2" - op: "Cast" - input: "sub_4" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "add" - op: "AddV2" - input: "Cast_2" - input: "add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Pow/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.800000011920929 - } - } - } -} -node { - name: "Pow" - op: "Pow" - input: "add" - input: "Pow/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "sub_5/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "sub_5" - op: "Sub" - input: "sub_5/x" - input: "Pow" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200}\000\000 \000\000\000" - } - } - } -} -node { - name: "shared/embedding_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "shared/embedding_slot_v/Initializer/zeros" - op: "Fill" - input: "shared/embedding_slot_v/Initializer/zeros/shape_as_tensor" - input: "shared/embedding_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "shared/embedding_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "shared/embedding_slot_v/Assign" - op: "Assign" - input: "shared/embedding_slot_v" - input: "shared/embedding_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "shared/embedding_slot_v/read" - op: "Identity" - input: "shared/embedding_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "shared/embedding/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "shared/embedding/adafactor/sub" - op: "Sub" - input: "shared/embedding/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "encoder/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "encoder/rms_norm/scale_slot_v" - input: "encoder/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/rms_norm/scale_slot_v/read" - op: "Identity" - input: "encoder/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "encoder/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/rms_norm/scale_slot_v/Initializer/zeros" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/rms_norm/scale_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/rms_norm/scale_slot_v/Assign" - op: "Assign" - input: "decoder/rms_norm/scale_slot_v" - input: "decoder/rms_norm/scale_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/rms_norm/scale_slot_v/read" - op: "Identity" - input: "decoder/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/sub" - op: "Sub" - input: "decoder/rms_norm/scale/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/logits/kernel_slot_v/Initializer/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200}\000\000" - } - } - } -} -node { - name: "decoder/logits/kernel_slot_v/Initializer/zeros/Const" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "decoder/logits/kernel_slot_v/Initializer/zeros" - op: "Fill" - input: "decoder/logits/kernel_slot_v/Initializer/zeros/shape_as_tensor" - input: "decoder/logits/kernel_slot_v/Initializer/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "decoder/logits/kernel_slot_v" - op: "VariableV2" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/logits/kernel_slot_v/Assign" - op: "Assign" - input: "decoder/logits/kernel_slot_v" - input: "decoder/logits/kernel_slot_v/Initializer/zeros" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/logits/kernel_slot_v/read" - op: "Identity" - input: "decoder/logits/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/maximum/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0010000000474974513 - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/sub" - op: "Sub" - input: "decoder/logits/kernel/adafactor/sub/x" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/maximum_1/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "reshape_6/parallel_0/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "reshape_6/parallel_0/Reshape" - op: "Reshape" - input: "Reshape" - input: "reshape_6/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "reshape_1_1/parallel_0/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "reshape_1_1/parallel_0/Reshape" - op: "Reshape" - input: "Reshape_1" - input: "reshape_1_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "reshape_2_1/parallel_0/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "reshape_2_1/parallel_0/Reshape" - op: "Reshape" - input: "Reshape_2" - input: "reshape_2_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "reshape_3_1/parallel_0/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "reshape_3_1/parallel_0/Reshape" - op: "Reshape" - input: "Reshape_3" - input: "reshape_3_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "reshape_4_1/parallel_0/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "reshape_4_1/parallel_0/Reshape" - op: "Reshape" - input: "Reshape_4" - input: "reshape_4_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "reshape_5_1/parallel_0/Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000 \000\000\000\000\002\000\000" - } - } - } -} -node { - name: "reshape_5_1/parallel_0/Reshape" - op: "Reshape" - input: "Reshape_5" - input: "reshape_5_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "constant/parallel_0/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "shared/embedding_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200}\000\000 \000\000\000" - } - } - } -} -node { - name: "shared/embedding_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.013658959418535233 - } - } - } -} -node { - name: "shared/embedding_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.013658959418535233 - } - } - } -} -node { - name: "shared/embedding_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "shared/embedding_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "shared/embedding_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "shared/embedding_slice_0/Initializer/random_uniform/max" - input: "shared/embedding_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "shared/embedding_slice_0/Initializer/random_uniform/RandomUniform" - input: "shared/embedding_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding_slice_0/Initializer/random_uniform" - op: "Add" - input: "shared/embedding_slice_0/Initializer/random_uniform/mul" - input: "shared/embedding_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "shared/embedding_slice_0/Assign" - op: "Assign" - input: "shared/embedding_slice_0" - input: "shared/embedding_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "shared/embedding_slice_0/read" - op: "Identity" - input: "shared/embedding_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding_1/Cast" - op: "Cast" - input: "shared/embedding_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding_1/parallel_0_1/Cast" - op: "Cast" - input: "shared/embedding/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding_1/parallel_0_1/Assign" - op: "Assign" - input: "shared/embedding_slice_0" - input: "shared/embedding_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "shared/embedding_1/group_deps" - op: "NoOp" - input: "^shared/embedding_1/parallel_0_1/Assign" -} -node { - name: "shared/embedding_1/Assign" - op: "Assign" - input: "shared/embedding" - input: "shared/embedding_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0" - input: "encoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale" - input: "encoder/block_000/layer_000/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/q/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q" - input: "encoder/block_000/layer_000/SelfAttention/q_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/k/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k" - input: "encoder/block_000/layer_000/SelfAttention/k_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/v/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v" - input: "encoder/block_000/layer_000/SelfAttention/v_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/o/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o" - input: "encoder/block_000/layer_000/SelfAttention/o_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\002\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.42008402943611145 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.42008402943611145 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0" - input: "encoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale" - input: "encoder/block_000/layer_001/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/read" - op: "Identity" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/group_deps" - op: "NoOp" - input: "^encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0" - input: "encoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale" - input: "encoder/block_001/layer_000/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/q/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q" - input: "encoder/block_001/layer_000/SelfAttention/q_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/k/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k" - input: "encoder/block_001/layer_000/SelfAttention/k_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/v/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v" - input: "encoder/block_001/layer_000/SelfAttention/v_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_000/SelfAttention/o/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0" - input: "encoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o" - input: "encoder/block_001/layer_000/SelfAttention/o_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0" - input: "encoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale" - input: "encoder/block_001/layer_001/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/read" - op: "Identity" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/group_deps" - op: "NoOp" - input: "^encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "encoder/rms_norm/scale_slice_0" - input: "encoder/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/rms_norm/scale_slice_0/read" - op: "Identity" - input: "encoder/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale_1/Cast" - op: "Cast" - input: "encoder/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "encoder/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/rms_norm/scale_slice_0" - input: "encoder/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "encoder/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^encoder/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "encoder/rms_norm/scale_1/Assign" - op: "Assign" - input: "encoder/rms_norm/scale" - input: "encoder/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0" - input: "decoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_000/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale" - input: "decoder/block_000/layer_000/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/q/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/q_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q" - input: "decoder/block_000/layer_000/SelfAttention/q_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/k/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/k_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k" - input: "decoder/block_000/layer_000/SelfAttention/k_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/v/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/v_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v" - input: "decoder/block_000/layer_000/SelfAttention/v_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/o/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/o_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o" - input: "decoder/block_000/layer_000/SelfAttention/o_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\002\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.42008402943611145 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.42008402943611145 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0" - input: "decoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_001/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale" - input: "decoder/block_000/layer_001/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/q/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/q_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_001/EncDecAttention/q_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q" - input: "decoder/block_000/layer_001/EncDecAttention/q_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/k/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/k_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_001/EncDecAttention/k_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k" - input: "decoder/block_000/layer_001/EncDecAttention/k_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/v/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/v_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_001/EncDecAttention/v_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v" - input: "decoder/block_000/layer_001/EncDecAttention/v_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_001/EncDecAttention/o/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0" - input: "decoder/block_000/layer_001/EncDecAttention/o_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_001/EncDecAttention/o_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o" - input: "decoder/block_000/layer_001/EncDecAttention/o_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0" - input: "decoder/block_000/layer_002/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_002/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale" - input: "decoder/block_000/layer_002/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/read" - op: "Identity" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0" - input: "decoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_000/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale" - input: "decoder/block_001/layer_000/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/q/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_000/SelfAttention/q_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q" - input: "decoder/block_001/layer_000/SelfAttention/q_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/k/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_000/SelfAttention/k_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k" - input: "decoder/block_001/layer_000/SelfAttention/k_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/v/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_000/SelfAttention/v_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v" - input: "decoder/block_001/layer_000/SelfAttention/v_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_000/SelfAttention/o/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0" - input: "decoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_000/SelfAttention/o_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o" - input: "decoder/block_001/layer_000/SelfAttention/o_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0" - input: "decoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_001/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale" - input: "decoder/block_001/layer_001/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/q/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/q_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_001/EncDecAttention/q_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q" - input: "decoder/block_001/layer_001/EncDecAttention/q_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/k/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/k_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_001/EncDecAttention/k_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k" - input: "decoder/block_001/layer_001/EncDecAttention/k_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/v/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/v_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_001/EncDecAttention/v_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v" - input: "decoder/block_001/layer_001/EncDecAttention/v_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.19364917278289795 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_001/EncDecAttention/o/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0" - input: "decoder/block_001/layer_001/EncDecAttention/o_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_001/EncDecAttention/o_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o" - input: "decoder/block_001/layer_001/EncDecAttention/o_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0" - input: "decoder/block_001/layer_002/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_002/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale" - input: "decoder/block_001/layer_002/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.25 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.25 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/max" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/read" - op: "Identity" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/group_deps" - op: "NoOp" - input: "^decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/parallel_0_1/Assign" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 32 - } - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.3061862289905548 - } - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3061862289905548 - } - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/max" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/mul" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/Assign" - op: "Assign" - input: "decoder/rms_norm/scale_slice_0" - input: "decoder/rms_norm/scale_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/rms_norm/scale_slice_0/read" - op: "Identity" - input: "decoder/rms_norm/scale_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale_1/Cast" - op: "Cast" - input: "decoder/rms_norm/scale_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/rms_norm/scale/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/rms_norm/scale_slice_0" - input: "decoder/rms_norm/scale_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/rms_norm/scale_1/group_deps" - op: "NoOp" - input: "^decoder/rms_norm/scale_1/parallel_0_1/Assign" -} -node { - name: "decoder/rms_norm/scale_1/Assign" - op: "Assign" - input: "decoder/rms_norm/scale" - input: "decoder/rms_norm/scale_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Initializer/random_uniform/shape" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200}\000\000" - } - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Initializer/random_uniform/min" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.013658959418535233 - } - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Initializer/random_uniform/max" - op: "Const" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.013658959418535233 - } - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Initializer/random_uniform/RandomUniform" - op: "RandomUniform" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Initializer/random_uniform/sub" - op: "Sub" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform/max" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Initializer/random_uniform/mul" - op: "Mul" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform/RandomUniform" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Initializer/random_uniform" - op: "Add" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform/mul" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform/min" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel_slice_0" - op: "VariableV2" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "decoder/logits/kernel_slice_0/Assign" - op: "Assign" - input: "decoder/logits/kernel_slice_0" - input: "decoder/logits/kernel_slice_0/Initializer/random_uniform" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/logits/kernel_slice_0/read" - op: "Identity" - input: "decoder/logits/kernel_slice_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel_1/Cast" - op: "Cast" - input: "decoder/logits/kernel_slice_0/read" - attr { - key: "DstT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel_1/parallel_0_1/Cast" - op: "Cast" - input: "decoder/logits/kernel/read" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel_1/parallel_0_1/Assign" - op: "Assign" - input: "decoder/logits/kernel_slice_0" - input: "decoder/logits/kernel_1/parallel_0_1/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "decoder/logits/kernel_1/group_deps" - op: "NoOp" - input: "^decoder/logits/kernel_1/parallel_0_1/Assign" -} -node { - name: "decoder/logits/kernel_1/Assign" - op: "Assign" - input: "decoder/logits/kernel" - input: "decoder/logits/kernel_1/Cast" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "while_loop/parallel_0/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_1/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200}\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_1/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_1/zeros" - op: "Fill" - input: "while_loop/parallel_0_1/zeros/shape_as_tensor" - input: "while_loop/parallel_0_1/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_2/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_3/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_3/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_3/zeros" - op: "Fill" - input: "while_loop/parallel_0_3/zeros/shape_as_tensor" - input: "while_loop/parallel_0_3/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_4/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_4/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_4/zeros" - op: "Fill" - input: "while_loop/parallel_0_4/zeros/shape_as_tensor" - input: "while_loop/parallel_0_4/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_5/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_5/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_5/zeros" - op: "Fill" - input: "while_loop/parallel_0_5/zeros/shape_as_tensor" - input: "while_loop/parallel_0_5/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_6/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_6/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_6/zeros" - op: "Fill" - input: "while_loop/parallel_0_6/zeros/shape_as_tensor" - input: "while_loop/parallel_0_6/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_7/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_8/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_9/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_9/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_9/zeros" - op: "Fill" - input: "while_loop/parallel_0_9/zeros/shape_as_tensor" - input: "while_loop/parallel_0_9/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_10/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_10/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_10/zeros" - op: "Fill" - input: "while_loop/parallel_0_10/zeros/shape_as_tensor" - input: "while_loop/parallel_0_10/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_11/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_11/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_11/zeros" - op: "Fill" - input: "while_loop/parallel_0_11/zeros/shape_as_tensor" - input: "while_loop/parallel_0_11/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_12/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_13/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_13/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_13/zeros" - op: "Fill" - input: "while_loop/parallel_0_13/zeros/shape_as_tensor" - input: "while_loop/parallel_0_13/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_14/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_14/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_14/zeros" - op: "Fill" - input: "while_loop/parallel_0_14/zeros/shape_as_tensor" - input: "while_loop/parallel_0_14/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_15/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_15/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_15/zeros" - op: "Fill" - input: "while_loop/parallel_0_15/zeros/shape_as_tensor" - input: "while_loop/parallel_0_15/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_16/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_16/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_16/zeros" - op: "Fill" - input: "while_loop/parallel_0_16/zeros/shape_as_tensor" - input: "while_loop/parallel_0_16/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_17/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_18/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_18/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_18/zeros" - op: "Fill" - input: "while_loop/parallel_0_18/zeros/shape_as_tensor" - input: "while_loop/parallel_0_18/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_19/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_19/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_19/zeros" - op: "Fill" - input: "while_loop/parallel_0_19/zeros/shape_as_tensor" - input: "while_loop/parallel_0_19/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_20/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_20/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_20/zeros" - op: "Fill" - input: "while_loop/parallel_0_20/zeros/shape_as_tensor" - input: "while_loop/parallel_0_20/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_21/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_22/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_23/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_23/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_23/zeros" - op: "Fill" - input: "while_loop/parallel_0_23/zeros/shape_as_tensor" - input: "while_loop/parallel_0_23/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_24/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_24/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_24/zeros" - op: "Fill" - input: "while_loop/parallel_0_24/zeros/shape_as_tensor" - input: "while_loop/parallel_0_24/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_25/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_25/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_25/zeros" - op: "Fill" - input: "while_loop/parallel_0_25/zeros/shape_as_tensor" - input: "while_loop/parallel_0_25/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_26/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_26/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_26/zeros" - op: "Fill" - input: "while_loop/parallel_0_26/zeros/shape_as_tensor" - input: "while_loop/parallel_0_26/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_27/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_28/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_29/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_29/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_29/zeros" - op: "Fill" - input: "while_loop/parallel_0_29/zeros/shape_as_tensor" - input: "while_loop/parallel_0_29/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_30/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_30/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_30/zeros" - op: "Fill" - input: "while_loop/parallel_0_30/zeros/shape_as_tensor" - input: "while_loop/parallel_0_30/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_31/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_31/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_31/zeros" - op: "Fill" - input: "while_loop/parallel_0_31/zeros/shape_as_tensor" - input: "while_loop/parallel_0_31/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_32/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_32/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_32/zeros" - op: "Fill" - input: "while_loop/parallel_0_32/zeros/shape_as_tensor" - input: "while_loop/parallel_0_32/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_33/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_34/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_34/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_34/zeros" - op: "Fill" - input: "while_loop/parallel_0_34/zeros/shape_as_tensor" - input: "while_loop/parallel_0_34/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_35/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_35/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_35/zeros" - op: "Fill" - input: "while_loop/parallel_0_35/zeros/shape_as_tensor" - input: "while_loop/parallel_0_35/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_36/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_36/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_36/zeros" - op: "Fill" - input: "while_loop/parallel_0_36/zeros/shape_as_tensor" - input: "while_loop/parallel_0_36/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_37/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_38/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_38/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_38/zeros" - op: "Fill" - input: "while_loop/parallel_0_38/zeros/shape_as_tensor" - input: "while_loop/parallel_0_38/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_39/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_39/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_39/zeros" - op: "Fill" - input: "while_loop/parallel_0_39/zeros/shape_as_tensor" - input: "while_loop/parallel_0_39/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_40/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_40/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_40/zeros" - op: "Fill" - input: "while_loop/parallel_0_40/zeros/shape_as_tensor" - input: "while_loop/parallel_0_40/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_41/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_41/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_41/zeros" - op: "Fill" - input: "while_loop/parallel_0_41/zeros/shape_as_tensor" - input: "while_loop/parallel_0_41/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_42/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_43/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_43/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_43/zeros" - op: "Fill" - input: "while_loop/parallel_0_43/zeros/shape_as_tensor" - input: "while_loop/parallel_0_43/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_44/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_44/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_44/zeros" - op: "Fill" - input: "while_loop/parallel_0_44/zeros/shape_as_tensor" - input: "while_loop/parallel_0_44/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_45/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_45/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_45/zeros" - op: "Fill" - input: "while_loop/parallel_0_45/zeros/shape_as_tensor" - input: "while_loop/parallel_0_45/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_46/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_46/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_46/zeros" - op: "Fill" - input: "while_loop/parallel_0_46/zeros/shape_as_tensor" - input: "while_loop/parallel_0_46/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_47/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_48/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_48/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_48/zeros" - op: "Fill" - input: "while_loop/parallel_0_48/zeros/shape_as_tensor" - input: "while_loop/parallel_0_48/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_49/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_49/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_49/zeros" - op: "Fill" - input: "while_loop/parallel_0_49/zeros/shape_as_tensor" - input: "while_loop/parallel_0_49/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_50/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "@\000\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_50/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_50/zeros" - op: "Fill" - input: "while_loop/parallel_0_50/zeros/shape_as_tensor" - input: "while_loop/parallel_0_50/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/parallel_0_51/zeros" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 32 - } - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_52/zeros/shape_as_tensor" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: " \000\000\000\200}\000\000" - } - } - } -} -node { - name: "while_loop/parallel_0_52/zeros/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/parallel_0_52/zeros" - op: "Fill" - input: "while_loop/parallel_0_52/zeros/shape_as_tensor" - input: "while_loop/parallel_0_52/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/Enter" - op: "Enter" - input: "constant/parallel_0/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_1" - op: "Enter" - input: "while_loop/parallel_0/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_2" - op: "Enter" - input: "while_loop/parallel_0_1/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_3" - op: "Enter" - input: "while_loop/parallel_0_2/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_4" - op: "Enter" - input: "while_loop/parallel_0_3/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_5" - op: "Enter" - input: "while_loop/parallel_0_4/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_6" - op: "Enter" - input: "while_loop/parallel_0_5/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_7" - op: "Enter" - input: "while_loop/parallel_0_6/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_8" - op: "Enter" - input: "while_loop/parallel_0_7/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_9" - op: "Enter" - input: "while_loop/parallel_0_8/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_10" - op: "Enter" - input: "while_loop/parallel_0_9/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_11" - op: "Enter" - input: "while_loop/parallel_0_10/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_12" - op: "Enter" - input: "while_loop/parallel_0_11/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_13" - op: "Enter" - input: "while_loop/parallel_0_12/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_14" - op: "Enter" - input: "while_loop/parallel_0_13/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_15" - op: "Enter" - input: "while_loop/parallel_0_14/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_16" - op: "Enter" - input: "while_loop/parallel_0_15/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_17" - op: "Enter" - input: "while_loop/parallel_0_16/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_18" - op: "Enter" - input: "while_loop/parallel_0_17/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_19" - op: "Enter" - input: "while_loop/parallel_0_18/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_20" - op: "Enter" - input: "while_loop/parallel_0_19/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_21" - op: "Enter" - input: "while_loop/parallel_0_20/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_22" - op: "Enter" - input: "while_loop/parallel_0_21/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_23" - op: "Enter" - input: "while_loop/parallel_0_22/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_24" - op: "Enter" - input: "while_loop/parallel_0_23/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_25" - op: "Enter" - input: "while_loop/parallel_0_24/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_26" - op: "Enter" - input: "while_loop/parallel_0_25/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_27" - op: "Enter" - input: "while_loop/parallel_0_26/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_28" - op: "Enter" - input: "while_loop/parallel_0_27/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_29" - op: "Enter" - input: "while_loop/parallel_0_28/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_30" - op: "Enter" - input: "while_loop/parallel_0_29/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_31" - op: "Enter" - input: "while_loop/parallel_0_30/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_32" - op: "Enter" - input: "while_loop/parallel_0_31/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_33" - op: "Enter" - input: "while_loop/parallel_0_32/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_34" - op: "Enter" - input: "while_loop/parallel_0_33/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_35" - op: "Enter" - input: "while_loop/parallel_0_34/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_36" - op: "Enter" - input: "while_loop/parallel_0_35/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_37" - op: "Enter" - input: "while_loop/parallel_0_36/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_38" - op: "Enter" - input: "while_loop/parallel_0_37/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_39" - op: "Enter" - input: "while_loop/parallel_0_38/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_40" - op: "Enter" - input: "while_loop/parallel_0_39/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_41" - op: "Enter" - input: "while_loop/parallel_0_40/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_42" - op: "Enter" - input: "while_loop/parallel_0_41/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_43" - op: "Enter" - input: "while_loop/parallel_0_42/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_44" - op: "Enter" - input: "while_loop/parallel_0_43/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_45" - op: "Enter" - input: "while_loop/parallel_0_44/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_46" - op: "Enter" - input: "while_loop/parallel_0_45/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_47" - op: "Enter" - input: "while_loop/parallel_0_46/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_48" - op: "Enter" - input: "while_loop/parallel_0_47/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_49" - op: "Enter" - input: "while_loop/parallel_0_48/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_50" - op: "Enter" - input: "while_loop/parallel_0_49/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_51" - op: "Enter" - input: "while_loop/parallel_0_50/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_52" - op: "Enter" - input: "while_loop/parallel_0_51/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Enter_53" - op: "Enter" - input: "while_loop/parallel_0_52/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/Merge" - op: "Merge" - input: "while_loop/while/Enter" - input: "while_loop/while/NextIteration" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_1" - op: "Merge" - input: "while_loop/while/Enter_1" - input: "while_loop/while/NextIteration_1" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_2" - op: "Merge" - input: "while_loop/while/Enter_2" - input: "while_loop/while/NextIteration_2" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_3" - op: "Merge" - input: "while_loop/while/Enter_3" - input: "while_loop/while/NextIteration_3" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_4" - op: "Merge" - input: "while_loop/while/Enter_4" - input: "while_loop/while/NextIteration_4" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_5" - op: "Merge" - input: "while_loop/while/Enter_5" - input: "while_loop/while/NextIteration_5" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_6" - op: "Merge" - input: "while_loop/while/Enter_6" - input: "while_loop/while/NextIteration_6" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_7" - op: "Merge" - input: "while_loop/while/Enter_7" - input: "while_loop/while/NextIteration_7" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_8" - op: "Merge" - input: "while_loop/while/Enter_8" - input: "while_loop/while/NextIteration_8" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_9" - op: "Merge" - input: "while_loop/while/Enter_9" - input: "while_loop/while/NextIteration_9" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_10" - op: "Merge" - input: "while_loop/while/Enter_10" - input: "while_loop/while/NextIteration_10" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_11" - op: "Merge" - input: "while_loop/while/Enter_11" - input: "while_loop/while/NextIteration_11" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_12" - op: "Merge" - input: "while_loop/while/Enter_12" - input: "while_loop/while/NextIteration_12" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_13" - op: "Merge" - input: "while_loop/while/Enter_13" - input: "while_loop/while/NextIteration_13" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_14" - op: "Merge" - input: "while_loop/while/Enter_14" - input: "while_loop/while/NextIteration_14" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_15" - op: "Merge" - input: "while_loop/while/Enter_15" - input: "while_loop/while/NextIteration_15" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_16" - op: "Merge" - input: "while_loop/while/Enter_16" - input: "while_loop/while/NextIteration_16" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_17" - op: "Merge" - input: "while_loop/while/Enter_17" - input: "while_loop/while/NextIteration_17" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_18" - op: "Merge" - input: "while_loop/while/Enter_18" - input: "while_loop/while/NextIteration_18" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_19" - op: "Merge" - input: "while_loop/while/Enter_19" - input: "while_loop/while/NextIteration_19" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_20" - op: "Merge" - input: "while_loop/while/Enter_20" - input: "while_loop/while/NextIteration_20" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_21" - op: "Merge" - input: "while_loop/while/Enter_21" - input: "while_loop/while/NextIteration_21" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_22" - op: "Merge" - input: "while_loop/while/Enter_22" - input: "while_loop/while/NextIteration_22" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_23" - op: "Merge" - input: "while_loop/while/Enter_23" - input: "while_loop/while/NextIteration_23" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_24" - op: "Merge" - input: "while_loop/while/Enter_24" - input: "while_loop/while/NextIteration_24" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_25" - op: "Merge" - input: "while_loop/while/Enter_25" - input: "while_loop/while/NextIteration_25" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_26" - op: "Merge" - input: "while_loop/while/Enter_26" - input: "while_loop/while/NextIteration_26" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_27" - op: "Merge" - input: "while_loop/while/Enter_27" - input: "while_loop/while/NextIteration_27" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_28" - op: "Merge" - input: "while_loop/while/Enter_28" - input: "while_loop/while/NextIteration_28" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_29" - op: "Merge" - input: "while_loop/while/Enter_29" - input: "while_loop/while/NextIteration_29" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_30" - op: "Merge" - input: "while_loop/while/Enter_30" - input: "while_loop/while/NextIteration_30" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_31" - op: "Merge" - input: "while_loop/while/Enter_31" - input: "while_loop/while/NextIteration_31" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_32" - op: "Merge" - input: "while_loop/while/Enter_32" - input: "while_loop/while/NextIteration_32" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_33" - op: "Merge" - input: "while_loop/while/Enter_33" - input: "while_loop/while/NextIteration_33" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_34" - op: "Merge" - input: "while_loop/while/Enter_34" - input: "while_loop/while/NextIteration_34" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_35" - op: "Merge" - input: "while_loop/while/Enter_35" - input: "while_loop/while/NextIteration_35" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_36" - op: "Merge" - input: "while_loop/while/Enter_36" - input: "while_loop/while/NextIteration_36" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_37" - op: "Merge" - input: "while_loop/while/Enter_37" - input: "while_loop/while/NextIteration_37" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_38" - op: "Merge" - input: "while_loop/while/Enter_38" - input: "while_loop/while/NextIteration_38" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_39" - op: "Merge" - input: "while_loop/while/Enter_39" - input: "while_loop/while/NextIteration_39" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_40" - op: "Merge" - input: "while_loop/while/Enter_40" - input: "while_loop/while/NextIteration_40" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_41" - op: "Merge" - input: "while_loop/while/Enter_41" - input: "while_loop/while/NextIteration_41" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_42" - op: "Merge" - input: "while_loop/while/Enter_42" - input: "while_loop/while/NextIteration_42" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_43" - op: "Merge" - input: "while_loop/while/Enter_43" - input: "while_loop/while/NextIteration_43" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_44" - op: "Merge" - input: "while_loop/while/Enter_44" - input: "while_loop/while/NextIteration_44" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_45" - op: "Merge" - input: "while_loop/while/Enter_45" - input: "while_loop/while/NextIteration_45" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_46" - op: "Merge" - input: "while_loop/while/Enter_46" - input: "while_loop/while/NextIteration_46" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_47" - op: "Merge" - input: "while_loop/while/Enter_47" - input: "while_loop/while/NextIteration_47" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_48" - op: "Merge" - input: "while_loop/while/Enter_48" - input: "while_loop/while/NextIteration_48" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_49" - op: "Merge" - input: "while_loop/while/Enter_49" - input: "while_loop/while/NextIteration_49" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_50" - op: "Merge" - input: "while_loop/while/Enter_50" - input: "while_loop/while/NextIteration_50" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_51" - op: "Merge" - input: "while_loop/while/Enter_51" - input: "while_loop/while/NextIteration_51" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_52" - op: "Merge" - input: "while_loop/while/Enter_52" - input: "while_loop/while/NextIteration_52" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Merge_53" - op: "Merge" - input: "while_loop/while/Enter_53" - input: "while_loop/while/NextIteration_53" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/binary_op/parallel_0/Less" - op: "Less" - input: "while_loop/while/Merge" - input: "while_loop/while/binary_op/parallel_0/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/binary_op/parallel_0/Less/Enter" - op: "Enter" - input: "Const_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/LoopCond" - op: "LoopCond" - input: "while_loop/while/binary_op/parallel_0/Less" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/Switch" - op: "Switch" - input: "while_loop/while/Merge" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Switch_1" - op: "Switch" - input: "while_loop/while/Merge_1" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - shape { - } - } - } - } -} -node { - name: "while_loop/while/Switch_2" - op: "Switch" - input: "while_loop/while/Merge_2" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_2" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_3" - op: "Switch" - input: "while_loop/while/Merge_3" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_3" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_4" - op: "Switch" - input: "while_loop/while/Merge_4" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_4" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_5" - op: "Switch" - input: "while_loop/while/Merge_5" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_5" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_6" - op: "Switch" - input: "while_loop/while/Merge_6" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_6" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_7" - op: "Switch" - input: "while_loop/while/Merge_7" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_7" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_8" - op: "Switch" - input: "while_loop/while/Merge_8" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_8" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_9" - op: "Switch" - input: "while_loop/while/Merge_9" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_9" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_10" - op: "Switch" - input: "while_loop/while/Merge_10" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_10" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_11" - op: "Switch" - input: "while_loop/while/Merge_11" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_11" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_12" - op: "Switch" - input: "while_loop/while/Merge_12" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_12" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_13" - op: "Switch" - input: "while_loop/while/Merge_13" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_13" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_14" - op: "Switch" - input: "while_loop/while/Merge_14" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_14" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_15" - op: "Switch" - input: "while_loop/while/Merge_15" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_15" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_16" - op: "Switch" - input: "while_loop/while/Merge_16" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_16" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_17" - op: "Switch" - input: "while_loop/while/Merge_17" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_17" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_18" - op: "Switch" - input: "while_loop/while/Merge_18" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_18" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_19" - op: "Switch" - input: "while_loop/while/Merge_19" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_19" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_20" - op: "Switch" - input: "while_loop/while/Merge_20" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_20" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_21" - op: "Switch" - input: "while_loop/while/Merge_21" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_21" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_22" - op: "Switch" - input: "while_loop/while/Merge_22" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_22" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_23" - op: "Switch" - input: "while_loop/while/Merge_23" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_23" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_24" - op: "Switch" - input: "while_loop/while/Merge_24" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_24" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_25" - op: "Switch" - input: "while_loop/while/Merge_25" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_25" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_26" - op: "Switch" - input: "while_loop/while/Merge_26" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_26" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_27" - op: "Switch" - input: "while_loop/while/Merge_27" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_27" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_28" - op: "Switch" - input: "while_loop/while/Merge_28" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_28" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_29" - op: "Switch" - input: "while_loop/while/Merge_29" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_29" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_30" - op: "Switch" - input: "while_loop/while/Merge_30" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_30" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_31" - op: "Switch" - input: "while_loop/while/Merge_31" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_31" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_32" - op: "Switch" - input: "while_loop/while/Merge_32" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_32" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_33" - op: "Switch" - input: "while_loop/while/Merge_33" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_33" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_34" - op: "Switch" - input: "while_loop/while/Merge_34" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_34" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_35" - op: "Switch" - input: "while_loop/while/Merge_35" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_35" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_36" - op: "Switch" - input: "while_loop/while/Merge_36" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_36" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_37" - op: "Switch" - input: "while_loop/while/Merge_37" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_37" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_38" - op: "Switch" - input: "while_loop/while/Merge_38" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_38" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_39" - op: "Switch" - input: "while_loop/while/Merge_39" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_39" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_40" - op: "Switch" - input: "while_loop/while/Merge_40" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_40" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_41" - op: "Switch" - input: "while_loop/while/Merge_41" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_41" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_42" - op: "Switch" - input: "while_loop/while/Merge_42" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_42" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_43" - op: "Switch" - input: "while_loop/while/Merge_43" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_43" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_44" - op: "Switch" - input: "while_loop/while/Merge_44" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_44" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_45" - op: "Switch" - input: "while_loop/while/Merge_45" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_45" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_46" - op: "Switch" - input: "while_loop/while/Merge_46" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_46" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_47" - op: "Switch" - input: "while_loop/while/Merge_47" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_47" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_48" - op: "Switch" - input: "while_loop/while/Merge_48" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_48" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_49" - op: "Switch" - input: "while_loop/while/Merge_49" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_49" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_50" - op: "Switch" - input: "while_loop/while/Merge_50" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_50" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_51" - op: "Switch" - input: "while_loop/while/Merge_51" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_51" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_52" - op: "Switch" - input: "while_loop/while/Merge_52" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_52" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Switch_53" - op: "Switch" - input: "while_loop/while/Merge_53" - input: "while_loop/while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while_loop/while/Merge_53" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity" - op: "Identity" - input: "while_loop/while/Switch:1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/Identity_1" - op: "Identity" - input: "while_loop/while/Switch_1:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/Identity_2" - op: "Identity" - input: "while_loop/while/Switch_2:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_3" - op: "Identity" - input: "while_loop/while/Switch_3:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_4" - op: "Identity" - input: "while_loop/while/Switch_4:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_5" - op: "Identity" - input: "while_loop/while/Switch_5:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_6" - op: "Identity" - input: "while_loop/while/Switch_6:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_7" - op: "Identity" - input: "while_loop/while/Switch_7:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_8" - op: "Identity" - input: "while_loop/while/Switch_8:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_9" - op: "Identity" - input: "while_loop/while/Switch_9:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_10" - op: "Identity" - input: "while_loop/while/Switch_10:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_11" - op: "Identity" - input: "while_loop/while/Switch_11:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_12" - op: "Identity" - input: "while_loop/while/Switch_12:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_13" - op: "Identity" - input: "while_loop/while/Switch_13:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_14" - op: "Identity" - input: "while_loop/while/Switch_14:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_15" - op: "Identity" - input: "while_loop/while/Switch_15:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_16" - op: "Identity" - input: "while_loop/while/Switch_16:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_17" - op: "Identity" - input: "while_loop/while/Switch_17:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_18" - op: "Identity" - input: "while_loop/while/Switch_18:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_19" - op: "Identity" - input: "while_loop/while/Switch_19:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_20" - op: "Identity" - input: "while_loop/while/Switch_20:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_21" - op: "Identity" - input: "while_loop/while/Switch_21:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_22" - op: "Identity" - input: "while_loop/while/Switch_22:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_23" - op: "Identity" - input: "while_loop/while/Switch_23:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_24" - op: "Identity" - input: "while_loop/while/Switch_24:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_25" - op: "Identity" - input: "while_loop/while/Switch_25:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_26" - op: "Identity" - input: "while_loop/while/Switch_26:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_27" - op: "Identity" - input: "while_loop/while/Switch_27:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_28" - op: "Identity" - input: "while_loop/while/Switch_28:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_29" - op: "Identity" - input: "while_loop/while/Switch_29:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_30" - op: "Identity" - input: "while_loop/while/Switch_30:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_31" - op: "Identity" - input: "while_loop/while/Switch_31:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_32" - op: "Identity" - input: "while_loop/while/Switch_32:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_33" - op: "Identity" - input: "while_loop/while/Switch_33:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_34" - op: "Identity" - input: "while_loop/while/Switch_34:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_35" - op: "Identity" - input: "while_loop/while/Switch_35:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_36" - op: "Identity" - input: "while_loop/while/Switch_36:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_37" - op: "Identity" - input: "while_loop/while/Switch_37:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_38" - op: "Identity" - input: "while_loop/while/Switch_38:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_39" - op: "Identity" - input: "while_loop/while/Switch_39:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_40" - op: "Identity" - input: "while_loop/while/Switch_40:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_41" - op: "Identity" - input: "while_loop/while/Switch_41:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_42" - op: "Identity" - input: "while_loop/while/Switch_42:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_43" - op: "Identity" - input: "while_loop/while/Switch_43:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_44" - op: "Identity" - input: "while_loop/while/Switch_44:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_45" - op: "Identity" - input: "while_loop/while/Switch_45:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_46" - op: "Identity" - input: "while_loop/while/Switch_46:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_47" - op: "Identity" - input: "while_loop/while/Switch_47:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_48" - op: "Identity" - input: "while_loop/while/Switch_48:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_49" - op: "Identity" - input: "while_loop/while/Switch_49:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_50" - op: "Identity" - input: "while_loop/while/Switch_50:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_51" - op: "Identity" - input: "while_loop/while/Switch_51:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_52" - op: "Identity" - input: "while_loop/while/Switch_52:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Identity_53" - op: "Identity" - input: "while_loop/while/Switch_53:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_6/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\004\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_6/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/reshape_6/parallel_0/Reshape/Enter" - input: "while_loop/while/reshape_6/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_6/parallel_0/Reshape/Enter" - op: "Enter" - input: "reshape_6/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/sub" - op: "Sub" - input: "while_loop/while/Identity" - input: "while_loop/while/one_hot/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/one_hot/parallel_0/Cast/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/Cast_1" - op: "Cast" - input: "while_loop/while/one_hot/parallel_0/Cast_1/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/one_hot/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/one_hot/parallel_0/sub" - input: "while_loop/while/one_hot/parallel_0/one_hot/depth" - input: "while_loop/while/one_hot/parallel_0/Cast" - input: "while_loop/while/one_hot/parallel_0/Cast_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/one_hot/parallel_0/one_hot" - input: "while_loop/while/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/einsum/parallel_0/transpose" - input: "while_loop/while/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/einsum/parallel_0/ExpandDims" - input: "while_loop/while/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/einsum/parallel_0/ExpandDims_2" - input: "while_loop/while/reshape_6/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/einsum/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/einsum/parallel_0/Mul" - input: "while_loop/while/einsum/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/reshape_7/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\004\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_7/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/reshape_7/parallel_0/Reshape/Enter" - input: "while_loop/while/reshape_7/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_7/parallel_0/Reshape/Enter" - op: "Enter" - input: "reshape_1_1/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/sub" - op: "Sub" - input: "while_loop/while/Identity" - input: "while_loop/while/one_hot_1/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/one_hot_1/parallel_0/Cast/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/Cast_1" - op: "Cast" - input: "while_loop/while/one_hot_1/parallel_0/Cast_1/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/one_hot_1/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/one_hot_1/parallel_0/sub" - input: "while_loop/while/one_hot_1/parallel_0/one_hot/depth" - input: "while_loop/while/one_hot_1/parallel_0/Cast" - input: "while_loop/while/one_hot_1/parallel_0/Cast_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/one_hot_1/parallel_0/one_hot" - input: "while_loop/while/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/einsum_1/parallel_0/transpose" - input: "while_loop/while/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/einsum_1/parallel_0/ExpandDims_2" - input: "while_loop/while/reshape_7/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/einsum_1/parallel_0/Mul" - input: "while_loop/while/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/reshape_8/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\004\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_8/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/reshape_8/parallel_0/Reshape/Enter" - input: "while_loop/while/reshape_8/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_8/parallel_0/Reshape/Enter" - op: "Enter" - input: "reshape_2_1/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/sub" - op: "Sub" - input: "while_loop/while/Identity" - input: "while_loop/while/one_hot_2/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/one_hot_2/parallel_0/Cast/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/Cast_1" - op: "Cast" - input: "while_loop/while/one_hot_2/parallel_0/Cast_1/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/one_hot_2/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/one_hot_2/parallel_0/sub" - input: "while_loop/while/one_hot_2/parallel_0/one_hot/depth" - input: "while_loop/while/one_hot_2/parallel_0/Cast" - input: "while_loop/while/one_hot_2/parallel_0/Cast_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/one_hot_2/parallel_0/one_hot" - input: "while_loop/while/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/einsum_2/parallel_0/transpose" - input: "while_loop/while/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/einsum_2/parallel_0/ExpandDims" - input: "while_loop/while/einsum_2/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/einsum_2/parallel_0/ExpandDims_1" - input: "while_loop/while/einsum_2/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/einsum_2/parallel_0/ExpandDims_2" - input: "while_loop/while/reshape_8/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/einsum_2/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/einsum_2/parallel_0/Mul" - input: "while_loop/while/einsum_2/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/reshape_9/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\004\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_9/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/reshape_9/parallel_0/Reshape/Enter" - input: "while_loop/while/reshape_9/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_9/parallel_0/Reshape/Enter" - op: "Enter" - input: "reshape_3_1/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/sub" - op: "Sub" - input: "while_loop/while/Identity" - input: "while_loop/while/one_hot_3/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/one_hot_3/parallel_0/Cast/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/Cast_1" - op: "Cast" - input: "while_loop/while/one_hot_3/parallel_0/Cast_1/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/one_hot_3/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/one_hot_3/parallel_0/sub" - input: "while_loop/while/one_hot_3/parallel_0/one_hot/depth" - input: "while_loop/while/one_hot_3/parallel_0/Cast" - input: "while_loop/while/one_hot_3/parallel_0/Cast_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/one_hot_3/parallel_0/one_hot" - input: "while_loop/while/einsum_3/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/einsum_3/parallel_0/transpose" - input: "while_loop/while/einsum_3/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/einsum_3/parallel_0/ExpandDims" - input: "while_loop/while/einsum_3/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/einsum_3/parallel_0/ExpandDims_1" - input: "while_loop/while/einsum_3/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/einsum_3/parallel_0/ExpandDims_2" - input: "while_loop/while/reshape_9/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/einsum_3/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/einsum_3/parallel_0/Mul" - input: "while_loop/while/einsum_3/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/reshape_10/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\004\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_10/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/reshape_10/parallel_0/Reshape/Enter" - input: "while_loop/while/reshape_10/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_10/parallel_0/Reshape/Enter" - op: "Enter" - input: "reshape_4_1/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/sub" - op: "Sub" - input: "while_loop/while/Identity" - input: "while_loop/while/one_hot_4/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/one_hot_4/parallel_0/Cast/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/Cast_1" - op: "Cast" - input: "while_loop/while/one_hot_4/parallel_0/Cast_1/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/one_hot_4/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/one_hot_4/parallel_0/sub" - input: "while_loop/while/one_hot_4/parallel_0/one_hot/depth" - input: "while_loop/while/one_hot_4/parallel_0/Cast" - input: "while_loop/while/one_hot_4/parallel_0/Cast_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/one_hot_4/parallel_0/one_hot" - input: "while_loop/while/einsum_4/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/einsum_4/parallel_0/transpose" - input: "while_loop/while/einsum_4/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/einsum_4/parallel_0/ExpandDims" - input: "while_loop/while/einsum_4/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/einsum_4/parallel_0/ExpandDims_1" - input: "while_loop/while/einsum_4/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/einsum_4/parallel_0/ExpandDims_2" - input: "while_loop/while/reshape_10/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/einsum_4/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/einsum_4/parallel_0/Mul" - input: "while_loop/while/einsum_4/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/reshape_11/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\004\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_11/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/reshape_11/parallel_0/Reshape/Enter" - input: "while_loop/while/reshape_11/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_11/parallel_0/Reshape/Enter" - op: "Enter" - input: "reshape_5_1/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/sub" - op: "Sub" - input: "while_loop/while/Identity" - input: "while_loop/while/one_hot_5/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/one_hot_5/parallel_0/Cast/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/Cast_1" - op: "Cast" - input: "while_loop/while/one_hot_5/parallel_0/Cast_1/x" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/one_hot_5/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/one_hot_5/parallel_0/sub" - input: "while_loop/while/one_hot_5/parallel_0/one_hot/depth" - input: "while_loop/while/one_hot_5/parallel_0/Cast" - input: "while_loop/while/one_hot_5/parallel_0/Cast_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/one_hot_5/parallel_0/one_hot" - input: "while_loop/while/einsum_5/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/einsum_5/parallel_0/transpose" - input: "while_loop/while/einsum_5/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/einsum_5/parallel_0/ExpandDims" - input: "while_loop/while/einsum_5/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/einsum_5/parallel_0/ExpandDims_1" - input: "while_loop/while/einsum_5/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/einsum_5/parallel_0/ExpandDims_2" - input: "while_loop/while/reshape_11/parallel_0/Reshape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 4 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/einsum_5/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/einsum_5/parallel_0/Mul" - input: "while_loop/while/einsum_5/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/range/range/range/start" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/range/range/range/limit" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/range/range/range/delta" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/range/range/range" - op: "Range" - input: "while_loop/while/range/range/range/start" - input: "while_loop/while/range/range/range/limit" - input: "while_loop/while/range/range/range/delta" - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/one_hot/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/one_hot/parallel_0/sub" - op: "Sub" - input: "while_loop/while/einsum/parallel_0/Sum" - input: "while_loop/while/encoder/one_hot/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/one_hot/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/encoder/one_hot/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/one_hot/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 32128 - } - } - } -} -node { - name: "while_loop/while/encoder/one_hot/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/encoder/one_hot/parallel_0/sub" - input: "while_loop/while/encoder/one_hot/parallel_0/one_hot/depth" - input: "while_loop/while/encoder/one_hot/parallel_0/Cast/x" - input: "while_loop/while/encoder/one_hot/parallel_0/Cast_1/x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/one_hot/parallel_0/one_hot" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "shared/embedding_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/parallel_0/Square" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_000/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/encoder/block_000/layer_000/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/range/range/range" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/range/range/range" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_2/Equal" - op: "Equal" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_2/Equal" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/negative_1/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add/parallel_0_2/Add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less" - op: "Less" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/negative_1/parallel_0/Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/to_int32/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 16 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/to_int32/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/sign/parallel_0/Sign" - op: "Sign" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/negative_1/parallel_0/Neg" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/negative_1/parallel_0/Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/sign/parallel_0/Sign" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - op: "Less" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/Const_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/to_float/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/to_float/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3606737554073334 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/log/parallel_0/Log" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 8.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/to_int32_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 8 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/to_int32_1/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Minimum" - op: "Minimum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Minimum/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Minimum/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/minimum/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/cast_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/cast_1/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/cast_2/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Minimum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/cast_2/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_2/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_3/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_2/parallel_0/Add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub" - op: "Sub" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_3/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 32 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,dc->abd" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/parallel_0_2/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - op: "Max" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/encoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/parallel_0/Square" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_001/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/encoder/block_000/layer_001/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "encoder/block_000/layer_001/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - op: "Tanh" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/encoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/parallel_0/Square" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_000/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/encoder/block_001/layer_000/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/range/range/range" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/range/range/range" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_2/Equal" - op: "Equal" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_2/Equal" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/negative_1/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add/parallel_0_2/Add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less" - op: "Less" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/negative_1/parallel_0/Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/SelfAttention/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/to_int32/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/Less" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 16 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/to_int32/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/sign/parallel_0/Sign" - op: "Sign" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/negative_1/parallel_0/Neg" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/negative_1/parallel_0/Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/sign/parallel_0/Sign" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - op: "Less" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/SelfAttention/Const_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/to_float/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/to_float/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.3606737554073334 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/log/parallel_0/Log" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 8.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/to_int32_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 8 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/to_int32_1/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Minimum" - op: "Minimum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Minimum/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Minimum/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/SelfAttention/minimum/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/cast_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/cast_1/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/cast_2/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_5/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Minimum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/cast_2/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_2/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_5/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_3/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_2/parallel_0/Add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub" - op: "Sub" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_3/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 32 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,dc->abd" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/ExpandDims_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/parallel_0_2/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - op: "Max" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/encoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/parallel_0/Square" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_001/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/encoder/block_001/layer_001/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "encoder/block_001/layer_001/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - op: "Tanh" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_add/parallel_0/add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/encoder/block_001/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/rms_norm/square/parallel_0/Square" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/encoder/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/encoder/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_1/parallel_0/transpose" - input: "while_loop/while/encoder/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/add/parallel_0/Add" - input: "while_loop/while/encoder/einsum_1/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/einsum_2/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/transpose/Enter" - op: "Enter" - input: "encoder/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_2/parallel_0/transpose" - input: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims" - input: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_2/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/encoder/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "encoder/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/encoder/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/cast/parallel_0/Cast" - input: "while_loop/while/encoder/einsum_3/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_3/parallel_0/transpose" - input: "while_loop/while/encoder/einsum_3/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_3/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_12/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_12/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/einsum_3/parallel_0/Mul" - input: "while_loop/while/reshape_12/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_13/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_13/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/einsum_1/parallel_0/Sum" - input: "while_loop/while/reshape_13/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/shift/parallel_0/Slice/begin" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "while_loop/while/shift/parallel_0/Slice/size" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\377\377\377\377\377\377\377\377\377\001\000\000" - } - } - } -} -node { - name: "while_loop/while/shift/parallel_0/Slice" - op: "Slice" - input: "while_loop/while/einsum_3/parallel_0/Sum" - input: "while_loop/while/shift/parallel_0/Slice/begin" - input: "while_loop/while/shift/parallel_0/Slice/size" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 511 - } - } - } - } - } -} -node { - name: "while_loop/while/shift/parallel_0/Pad/paddings" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "while_loop/while/shift/parallel_0/Pad" - op: "Pad" - input: "while_loop/while/shift/parallel_0/Slice" - input: "while_loop/while/shift/parallel_0/Pad/paddings" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tpaddings" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/sign/parallel_0/Sign" - op: "Sign" - input: "while_loop/while/shift/parallel_0/Pad" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_6/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/shift/parallel_0/Pad" - input: "while_loop/while/sign/parallel_0/Sign" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/shift_1/parallel_0/Slice/begin" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "while_loop/while/shift_1/parallel_0/Slice/size" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\377\377\377\377\377\377\377\377\377\001\000\000" - } - } - } -} -node { - name: "while_loop/while/shift_1/parallel_0/Slice" - op: "Slice" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/shift_1/parallel_0/Slice/begin" - input: "while_loop/while/shift_1/parallel_0/Slice/size" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 511 - } - } - } - } - } -} -node { - name: "while_loop/while/shift_1/parallel_0/Pad/paddings" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "while_loop/while/shift_1/parallel_0/Pad" - op: "Pad" - input: "while_loop/while/shift_1/parallel_0/Slice" - input: "while_loop/while/shift_1/parallel_0/Pad/paddings" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tpaddings" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/binary_op_1/parallel_0/Equal" - op: "Equal" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/shift_1/parallel_0/Pad" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/to_int32/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/binary_op_1/parallel_0/Equal" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/einsum_7/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/einsum_6/parallel_0/Mul" - input: "while_loop/while/to_int32/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/reshape_14/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_14/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/einsum/parallel_0/Sum" - input: "while_loop/while/reshape_14/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/range_1/range_1/range/start" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/range_1/range_1/range/limit" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/range_1/range_1/range/delta" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/range_1/range_1/range" - op: "Range" - input: "while_loop/while/range_1/range_1/range/start" - input: "while_loop/while/range_1/range_1/range/limit" - input: "while_loop/while/range_1/range_1/range/delta" - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot/parallel_0/sub" - op: "Sub" - input: "while_loop/while/einsum_7/parallel_0/Mul" - input: "while_loop/while/decoder/one_hot/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 32128 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/decoder/one_hot/parallel_0/sub" - input: "while_loop/while/decoder/one_hot/parallel_0/one_hot/depth" - input: "while_loop/while/decoder/one_hot/parallel_0/Cast/x" - input: "while_loop/while/decoder/one_hot/parallel_0/Cast_1/x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/decoder/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/one_hot/parallel_0/one_hot" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/parallel_0/Square" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_000/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/decoder/block_000/layer_000/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_000/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/range_1/range_1/range" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/range_1/range_1/range" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/range_1/range_1/range" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/range_1/range_1/range" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_2/GreaterEqual" - op: "GreaterEqual" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op/parallel_0_2/GreaterEqual" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_2/Equal" - op: "Equal" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_1/parallel_0_2/Equal" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast_1/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/negative_1/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add/parallel_0_2/Add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - op: "Maximum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/negative_1/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Maximum/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Maximum/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/maximum/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - op: "Less" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/to_float/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0625 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/to_float/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.48089835047721863 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/log/parallel_0/Log" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 16.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/to_int32/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 16 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/to_int32/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_2/parallel_0_1/Minimum" - op: "Minimum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_2/parallel_0_1/Minimum/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_2/parallel_0_1/Minimum/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/minimum/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast_2/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast_2/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/logical_not_2/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast_3/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/logical_not_2/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_2/parallel_0_1/Minimum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/cast_3/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_3/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_3/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub" - op: "Sub" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 32 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/sub" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,dc->abd" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_4/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/parallel_0_2/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0/transpose" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - op: "Max" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/parallel_0/Square" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_001/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/decoder/block_000/layer_001/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "decoder/block_000/layer_001/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_001/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/reshape_13/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_2/Equal" - op: "Equal" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/logical_not/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/binary_op/parallel_0_2/Equal" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/logical_not/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_3/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_3/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - op: "Max" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_4/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_4/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/parallel_0/Square" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_002/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/decoder/block_000/layer_002/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_002/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "decoder/block_000/layer_002/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_000/layer_002/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_002/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - op: "Tanh" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_000/layer_002/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/parallel_0/Square" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_000/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/decoder/block_001/layer_000/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_000/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/range_1/range_1/range" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/range_1/range_1/range" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_4/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/range_1/range_1/range" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/range_1/range_1/range" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_5/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_2/GreaterEqual" - op: "GreaterEqual" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op/parallel_0_2/GreaterEqual" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/logical_not/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_6/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_2/Equal" - op: "Equal" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_1/parallel_0_2/Equal" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/logical_not_1/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast_1/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/negative_1/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add/parallel_0_2/Add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - op: "Maximum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/negative_1/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Maximum/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Maximum/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/SelfAttention/maximum/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - op: "Less" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/SelfAttention/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/to_float/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0625 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/to_float/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.48089835047721863 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/log/parallel_0/Log" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 16.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_3/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/to_int32/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_4/parallel_0/mul" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 16 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/to_int32/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_2/parallel_0_1/Minimum" - op: "Minimum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_2/parallel_0_1/Minimum/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_2/parallel_0_1/Minimum/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/SelfAttention/minimum/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast_2/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_1/parallel_0_1/Maximum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast_2/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/logical_not_2/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/binary_op_2/parallel_0_1/Less" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast_3/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/logical_not_2/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_2/parallel_0_1/Minimum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/cast_3/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_3/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_3/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_4/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_3/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub" - op: "Sub" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_add_1/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 32 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/sub" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot/depth" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast/x" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/Cast_1/x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,dc->abd" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/scalar_mul_1/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_4/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_2/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_1/ExpandDims_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/parallel_0_2/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0/transpose" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - op: "Max" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_7/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_002/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/parallel_0/Square" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_001/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/decoder/block_001/layer_001/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "decoder/block_001/layer_001/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_001/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose" - op: "Transpose" - input: "while_loop/while/reshape_13/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose/perm" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_2/Equal" - op: "Equal" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_1/ExpandDims" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/logical_not/parallel_0/LogicalNot" - op: "LogicalNot" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/binary_op/parallel_0_2/Equal" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/logical_not/parallel_0/LogicalNot" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -1000000000.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_3/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_3/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - op: "Max" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/parallel_0/Log" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_4/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_4/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_001/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/parallel_0/Square" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_002/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/decoder/block_001/layer_002/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/transpose/Enter" - op: "Enter" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_002/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "decoder/block_001/layer_002/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/block_001/layer_002/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_002/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - op: "Tanh" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_add/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_001/layer_002/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/rms_norm/square/parallel_0/Square" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/parallel_0/Sum" - input: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999974752427e-07 - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "while_loop/while/decoder/scalar_add/parallel_0/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/einsum_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_1/parallel_0/transpose" - input: "while_loop/while/decoder/einsum_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/add/parallel_0/Add" - input: "while_loop/while/decoder/einsum_1/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/einsum_2/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/einsum_2/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/transpose/Enter" - op: "Enter" - input: "decoder/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_2/parallel_0/transpose" - input: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims" - input: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_2/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/binary_op/parallel_0_1/NotEqual" - op: "NotEqual" - input: "while_loop/while/einsum_4/parallel_0/Sum" - input: "while_loop/while/decoder/binary_op/parallel_0_1/NotEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "incompatible_shape_error" - value { - b: true - } - } -} -node { - name: "while_loop/while/decoder/binary_op/parallel_0_1/NotEqual/Enter" - op: "Enter" - input: "decoder/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/cast/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/binary_op/parallel_0_1/NotEqual" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/cast/parallel_0/Cast" - input: "while_loop/while/decoder/einsum_3/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_3/parallel_0/transpose" - input: "while_loop/while/decoder/einsum_3/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_3/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/logits/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/einsum_3/parallel_0/Mul" - input: "while_loop/while/decoder/logits/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,de->abce" - } - } -} -node { - name: "while_loop/while/decoder/logits/einsum/parallel_0/einsum/Einsum/Enter" - op: "Enter" - input: "decoder/logits/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/add/parallel_0_1/Maximum" - op: "Maximum" - input: "while_loop/while/einsum_3/parallel_0/Sum" - input: "while_loop/while/decoder/add/parallel_0_1/Maximum/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/add/parallel_0_1/Maximum/Enter" - op: "Enter" - input: "decoder/maximum/Const" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/one_hot_1/parallel_0/sub/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot_1/parallel_0/sub" - op: "Sub" - input: "while_loop/while/decoder/add/parallel_0_1/Maximum" - input: "while_loop/while/decoder/one_hot_1/parallel_0/sub/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot_1/parallel_0/Cast/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot_1/parallel_0/Cast_1/x" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot_1/parallel_0/one_hot/depth" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 32128 - } - } - } -} -node { - name: "while_loop/while/decoder/one_hot_1/parallel_0/one_hot" - op: "OneHot" - input: "while_loop/while/decoder/one_hot_1/parallel_0/sub" - input: "while_loop/while/decoder/one_hot_1/parallel_0/one_hot/depth" - input: "while_loop/while/decoder/one_hot_1/parallel_0/Cast/x" - input: "while_loop/while/decoder/one_hot_1/parallel_0/Cast_1/x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce_max/parallel_0/Max" - op: "Max" - input: "while_loop/while/decoder/logits/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/reduce_logsumexp/reduce_max/parallel_0/Max/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/reduce_logsumexp/negative/parallel_0/Neg" - input: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/transpose" - input: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/logits/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/exp/parallel_0/Exp" - op: "Exp" - input: "while_loop/while/decoder/reduce_logsumexp/add/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/reduce_logsumexp/exp/parallel_0/Exp" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/log/parallel_0/Log" - op: "Log" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add_1/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/reduce_logsumexp/log/parallel_0/Log" - input: "while_loop/while/decoder/reduce_logsumexp/reduce_max/parallel_0/Max" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/add_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/add_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/negative/parallel_0/Neg" - input: "while_loop/while/decoder/add_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/add_1/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/add_1/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/add_1/parallel_0/transpose" - input: "while_loop/while/decoder/add_1/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/add_1/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/logits/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/add_1/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_4/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/add_1/parallel_0_1/Add" - input: "while_loop/while/decoder/one_hot_1/parallel_0/one_hot" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/einsum_4/parallel_0/Mul" - input: "while_loop/while/decoder/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/negative_1/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999747378752e-05 - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/square/parallel_0/Square" - input: "while_loop/while/decoder/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/add_2/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/negative_1/parallel_0/Neg" - input: "while_loop/while/decoder/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/binary_op_1/parallel_0_1/Greater" - op: "Greater" - input: "while_loop/while/einsum_3/parallel_0/Sum" - input: "while_loop/while/decoder/binary_op_1/parallel_0_1/Greater/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/binary_op_1/parallel_0_1/Greater/Enter" - op: "Enter" - input: "decoder/Const_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/cast_1/parallel_0/Cast" - op: "Cast" - input: "while_loop/while/decoder/binary_op_1/parallel_0_1/Greater" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_5/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/add_2/parallel_0/Add" - input: "while_loop/while/decoder/cast_1/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/einsum_5/parallel_0/Mul" - input: "while_loop/while/decoder/reduce_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul_1/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 6.103515625e-05 - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/reduce_1/parallel_0/Sum" - input: "while_loop/while/decoder/scalar_mul_1/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/scalar_add/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/scalar_add/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/scalar_mul_1/parallel_0/mul" - input: "while_loop/while/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/constant_1/parallel_0/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 6.103515625e-05 - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/constant_1/parallel_0/Const" - input: "while_loop/while/decoder/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_5/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/cast_1/parallel_0/Cast" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_5/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/reduce_1/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/add_2/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.999999747378752e-05 - } - } - } -} -node { - name: "while_loop/while/decoder/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_5/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/reduce_logsumexp/add_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/negative_1/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/einsum_5/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200}\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/negative_1/gradients/negative/parallel_0/Neg" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_4/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/one_hot_1/parallel_0/one_hot" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_4/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/add_1/parallel_0_1/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/add_1/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/add_1/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/einsum_4/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/add_1/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/negative/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/add_1/gradients/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/negative/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/square/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/negative/gradients/negative/parallel_0/Neg" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/negative/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200}\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/reduce_logsumexp/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/reduce_logsumexp/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/einsum_4/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/logits/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/logits/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/logits/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/einsum_3/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/cast/parallel_0/Cast" - input: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/logits/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/logits/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_3/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/einsum_3/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_3/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/einsum_2/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/decoder/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_002/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - op: "Enter" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/sub/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_002/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_4/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_4/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abfde" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_4/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/negative/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/negative/gradients/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_3/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_3/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abedf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_4/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_3/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_3/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abfde" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_6/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abedf" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_5/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,cd->abd" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,abd->cd" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_6/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_002/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - op: "Enter" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/sub/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_002/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_4/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_4/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_4/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abfde" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_4/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/negative/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/add/gradients/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/negative/gradients/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_3/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_3/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abedf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_4/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_3/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/reshape_12/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_3/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_002/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_8/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abfde" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_6/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abedf" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_6/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,cd->abd" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,abd->cd" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_6/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/decoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/block_000/layer_000/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/decoder/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/decoder/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/decoder/block_000/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/decoder/one_hot/parallel_0/one_hot" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/reshape_12/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/reshape_12/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/reshape_12/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/cast/parallel_0/Cast" - input: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/reshape_12/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/reshape_12/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_3/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/einsum_3/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_3/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/einsum_2/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/encoder/block_001/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - op: "Enter" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/sub/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/encoder/block_001/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abfde" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abedf" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_6/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,cd->abd" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,abd->cd" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/encoder/block_000/layer_001/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_3/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/einsum/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/parallel_0/Tanh" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/square/parallel_0/Square" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/negative/parallel_0/Neg" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add/Enter" - op: "Enter" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/sub/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "frame_name" - value { - s: "while_loop/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/add/parallel_0_1/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_2/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.7978845834732056 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/tanh/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_3/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul_1/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_2/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.044714998453855515 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/scalar_mul/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_001/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/encoder/block_000/layer_000/add/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000@\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_7/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abfde->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_6/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abfde" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - op: "Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/add/gradients/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/parallel_0/Sum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/negative/gradients/negative/parallel_0/Neg" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\002\000\000\000\000\002\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/log/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/parallel_0/Exp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/exp/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/exp/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 5 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 5 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000\004\000\000\000\003\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abedf->abcdf" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/softmax/reduce_logsumexp/add/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 2 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcde,abcdf->abedf" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 4 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_5/gradients/reduce/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 2 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,cd->abd" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/add_4/gradients/reduce_1/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/one_hot/parallel_0/one_hot" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abc,abd->cd" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/gradients/einsum_1/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_3/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_2/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/parallel_0/Reshape" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000\200\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - op: "Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_7/gradients/einsum/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape/shape" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/parallel_0/Mul" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/reshape_1/gradients/reshape/parallel_0/Reshape" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/cast/parallel_0/Cast" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/parallel_0/transpose/Enter" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_1" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/ExpandDims_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_2/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - op: "Sum" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_001/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: -0.5 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/einsum_1/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/rsqrt/gradients/einsum_2/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 4 - } - } - tensor_content: "\001\000\000\000\010\000\000\000\000\002\000\000 \000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - op: "Fill" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/shape_as_tensor" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 3 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - op: "Transpose" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/scalar_mul/gradients/scalar_mul/parallel_0/mul" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose/perm" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tperm" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - op: "ExpandDims" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/transpose" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/zeros" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/ExpandDims" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/reduce_mean/reduce/gradients/broadcast/parallel_0/add" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - op: "Mul" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/einsum/parallel_0/Mul" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/encoder/block_000/layer_000/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/encoder/einsum/gradients/einsum/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/einsum/parallel_0/einsum/Einsum/Enter" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 8 - } - dim { - size: 512 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,ed->abce" - } - } -} -node { - name: "while_loop/while/encoder/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - op: "Einsum" - input: "while_loop/while/encoder/block_000/layer_000/rms_norm/square/gradients/add/parallel_0/Add" - input: "while_loop/while/encoder/one_hot/parallel_0/one_hot" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "equation" - value { - s: "abcd,abce->ed" - } - } -} -node { - name: "while_loop/while/encoder/einsum/gradients/add/parallel_0/Add" - op: "Add" - input: "while_loop/while/decoder/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/encoder/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/scalar_add_1/parallel_0/add/y" - op: "Const" - input: "^while_loop/while/Identity" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while_loop/while/scalar_add_1/parallel_0/add" - op: "AddV2" - input: "while_loop/while/Identity" - input: "while_loop/while/scalar_add_1/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/add" - op: "AddV2" - input: "while_loop/while/scalar_add/parallel_0/add" - input: "while_loop/while/Identity_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/add_1" - op: "AddV2" - input: "while_loop/while/encoder/einsum/gradients/add/parallel_0/Add" - input: "while_loop/while/Identity_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_2" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_3" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_4" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_5" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_6" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_9/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_7" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_000/SelfAttention/einsum_6/gradients/add/parallel_0/Add" - input: "while_loop/while/Identity_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_8" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_9" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_9" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_10" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_10" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_11" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_11" - op: "AddV2" - input: "while_loop/while/encoder/block_000/layer_001/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_12" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_12" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_13" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_13" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_14" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_14" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_15" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_15" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_16" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_16" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_000/SelfAttention/einsum_9/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_17" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_17" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_18" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_18" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_19" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_19" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_20" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_20" - op: "AddV2" - input: "while_loop/while/encoder/block_001/layer_001/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_21" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_21" - op: "AddV2" - input: "while_loop/while/encoder/einsum_2/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_22" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_22" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_23" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_23" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_24" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_24" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_25" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_25" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_26" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_26" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_27" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_27" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_000/SelfAttention/einsum_5/gradients/add/parallel_0/Add" - input: "while_loop/while/Identity_28" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_28" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_29" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_29" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_30" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_30" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_31" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_31" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_32" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_32" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_001/EncDecAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_33" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_33" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_34" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_34" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_35" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_35" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_36" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_36" - op: "AddV2" - input: "while_loop/while/decoder/block_000/layer_002/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_37" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_37" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_38" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_38" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_39" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_39" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_40" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_40" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_41" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_41" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_000/SelfAttention/einsum_8/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_42" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_42" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_43" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_43" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_44" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_44" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_1/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_45" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_45" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_2/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_46" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/add_46" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_001/EncDecAttention/einsum_5/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_47" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_47" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_002/einsum_1/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_48" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_48" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_0/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_49" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_49" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wi_1/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_50" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/add_50" - op: "AddV2" - input: "while_loop/while/decoder/block_001/layer_002/DenseReluDense/wo/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_51" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_51" - op: "AddV2" - input: "while_loop/while/decoder/einsum_2/gradients/einsum_1/parallel_0/Sum" - input: "while_loop/while/Identity_52" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/add_52" - op: "AddV2" - input: "while_loop/while/decoder/logits/einsum/gradients/einsum_1/parallel_0/einsum/Einsum" - input: "while_loop/while/Identity_53" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration" - op: "NextIteration" - input: "while_loop/while/scalar_add_1/parallel_0/add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_1" - op: "NextIteration" - input: "while_loop/while/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_2" - op: "NextIteration" - input: "while_loop/while/add_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_3" - op: "NextIteration" - input: "while_loop/while/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_4" - op: "NextIteration" - input: "while_loop/while/add_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_5" - op: "NextIteration" - input: "while_loop/while/add_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_6" - op: "NextIteration" - input: "while_loop/while/add_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_7" - op: "NextIteration" - input: "while_loop/while/add_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_8" - op: "NextIteration" - input: "while_loop/while/add_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_9" - op: "NextIteration" - input: "while_loop/while/add_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_10" - op: "NextIteration" - input: "while_loop/while/add_9" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_11" - op: "NextIteration" - input: "while_loop/while/add_10" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_12" - op: "NextIteration" - input: "while_loop/while/add_11" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_13" - op: "NextIteration" - input: "while_loop/while/add_12" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_14" - op: "NextIteration" - input: "while_loop/while/add_13" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_15" - op: "NextIteration" - input: "while_loop/while/add_14" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_16" - op: "NextIteration" - input: "while_loop/while/add_15" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_17" - op: "NextIteration" - input: "while_loop/while/add_16" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_18" - op: "NextIteration" - input: "while_loop/while/add_17" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_19" - op: "NextIteration" - input: "while_loop/while/add_18" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_20" - op: "NextIteration" - input: "while_loop/while/add_19" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_21" - op: "NextIteration" - input: "while_loop/while/add_20" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_22" - op: "NextIteration" - input: "while_loop/while/add_21" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_23" - op: "NextIteration" - input: "while_loop/while/add_22" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_24" - op: "NextIteration" - input: "while_loop/while/add_23" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_25" - op: "NextIteration" - input: "while_loop/while/add_24" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_26" - op: "NextIteration" - input: "while_loop/while/add_25" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_27" - op: "NextIteration" - input: "while_loop/while/add_26" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_28" - op: "NextIteration" - input: "while_loop/while/add_27" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_29" - op: "NextIteration" - input: "while_loop/while/add_28" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_30" - op: "NextIteration" - input: "while_loop/while/add_29" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_31" - op: "NextIteration" - input: "while_loop/while/add_30" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_32" - op: "NextIteration" - input: "while_loop/while/add_31" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_33" - op: "NextIteration" - input: "while_loop/while/add_32" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_34" - op: "NextIteration" - input: "while_loop/while/add_33" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_35" - op: "NextIteration" - input: "while_loop/while/add_34" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_36" - op: "NextIteration" - input: "while_loop/while/add_35" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_37" - op: "NextIteration" - input: "while_loop/while/add_36" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_38" - op: "NextIteration" - input: "while_loop/while/add_37" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_39" - op: "NextIteration" - input: "while_loop/while/add_38" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_40" - op: "NextIteration" - input: "while_loop/while/add_39" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_41" - op: "NextIteration" - input: "while_loop/while/add_40" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_42" - op: "NextIteration" - input: "while_loop/while/add_41" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_43" - op: "NextIteration" - input: "while_loop/while/add_42" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_44" - op: "NextIteration" - input: "while_loop/while/add_43" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_45" - op: "NextIteration" - input: "while_loop/while/add_44" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_46" - op: "NextIteration" - input: "while_loop/while/add_45" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_47" - op: "NextIteration" - input: "while_loop/while/add_46" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_48" - op: "NextIteration" - input: "while_loop/while/add_47" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_49" - op: "NextIteration" - input: "while_loop/while/add_48" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_50" - op: "NextIteration" - input: "while_loop/while/add_49" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_51" - op: "NextIteration" - input: "while_loop/while/add_50" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_52" - op: "NextIteration" - input: "while_loop/while/add_51" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/NextIteration_53" - op: "NextIteration" - input: "while_loop/while/add_52" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit" - op: "Exit" - input: "while_loop/while/Switch" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/Exit_1" - op: "Exit" - input: "while_loop/while/Switch_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "while_loop/while/Exit_2" - op: "Exit" - input: "while_loop/while/Switch_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_3" - op: "Exit" - input: "while_loop/while/Switch_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_4" - op: "Exit" - input: "while_loop/while/Switch_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_5" - op: "Exit" - input: "while_loop/while/Switch_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_6" - op: "Exit" - input: "while_loop/while/Switch_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_7" - op: "Exit" - input: "while_loop/while/Switch_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_8" - op: "Exit" - input: "while_loop/while/Switch_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_9" - op: "Exit" - input: "while_loop/while/Switch_9" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_10" - op: "Exit" - input: "while_loop/while/Switch_10" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_11" - op: "Exit" - input: "while_loop/while/Switch_11" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_12" - op: "Exit" - input: "while_loop/while/Switch_12" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_13" - op: "Exit" - input: "while_loop/while/Switch_13" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_14" - op: "Exit" - input: "while_loop/while/Switch_14" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_15" - op: "Exit" - input: "while_loop/while/Switch_15" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_16" - op: "Exit" - input: "while_loop/while/Switch_16" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_17" - op: "Exit" - input: "while_loop/while/Switch_17" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_18" - op: "Exit" - input: "while_loop/while/Switch_18" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_19" - op: "Exit" - input: "while_loop/while/Switch_19" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_20" - op: "Exit" - input: "while_loop/while/Switch_20" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_21" - op: "Exit" - input: "while_loop/while/Switch_21" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_22" - op: "Exit" - input: "while_loop/while/Switch_22" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_23" - op: "Exit" - input: "while_loop/while/Switch_23" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_24" - op: "Exit" - input: "while_loop/while/Switch_24" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_25" - op: "Exit" - input: "while_loop/while/Switch_25" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_26" - op: "Exit" - input: "while_loop/while/Switch_26" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_27" - op: "Exit" - input: "while_loop/while/Switch_27" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_28" - op: "Exit" - input: "while_loop/while/Switch_28" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_29" - op: "Exit" - input: "while_loop/while/Switch_29" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_30" - op: "Exit" - input: "while_loop/while/Switch_30" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_31" - op: "Exit" - input: "while_loop/while/Switch_31" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_32" - op: "Exit" - input: "while_loop/while/Switch_32" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_33" - op: "Exit" - input: "while_loop/while/Switch_33" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_34" - op: "Exit" - input: "while_loop/while/Switch_34" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_35" - op: "Exit" - input: "while_loop/while/Switch_35" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_36" - op: "Exit" - input: "while_loop/while/Switch_36" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_37" - op: "Exit" - input: "while_loop/while/Switch_37" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_38" - op: "Exit" - input: "while_loop/while/Switch_38" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_39" - op: "Exit" - input: "while_loop/while/Switch_39" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_40" - op: "Exit" - input: "while_loop/while/Switch_40" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_41" - op: "Exit" - input: "while_loop/while/Switch_41" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_42" - op: "Exit" - input: "while_loop/while/Switch_42" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_43" - op: "Exit" - input: "while_loop/while/Switch_43" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_44" - op: "Exit" - input: "while_loop/while/Switch_44" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_45" - op: "Exit" - input: "while_loop/while/Switch_45" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_46" - op: "Exit" - input: "while_loop/while/Switch_46" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_47" - op: "Exit" - input: "while_loop/while/Switch_47" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_48" - op: "Exit" - input: "while_loop/while/Switch_48" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_49" - op: "Exit" - input: "while_loop/while/Switch_49" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_50" - op: "Exit" - input: "while_loop/while/Switch_50" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_51" - op: "Exit" - input: "while_loop/while/Switch_51" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_52" - op: "Exit" - input: "while_loop/while/Switch_52" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "while_loop/while/Exit_53" - op: "Exit" - input: "while_loop/while/Switch_53" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "shared/embedding_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "shared/embedding_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "shared/embedding/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "shared/embedding/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "shared/embedding/adafactor/square/parallel_0/Square" - input: "shared/embedding/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "shared/embedding_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "shared/embedding/adafactor/square_1/parallel_0/Square" - input: "shared/embedding/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.72671841736883e-07 - } - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "shared/embedding/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "shared/embedding/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "shared/embedding/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "shared/embedding/adafactor/sqrt/parallel_0/Sqrt" - input: "shared/embedding/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "shared/embedding/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "shared/embedding_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "shared/embedding/adafactor/scalar_add/parallel_0/add" - input: "shared/embedding/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "shared/embedding/adafactor/scalar_mul_1/parallel_0/mul" - input: "shared/embedding/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "shared/embedding/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_2" - input: "shared/embedding/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "shared/embedding/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "shared/embedding/adafactor/square_2/parallel_0/Square" - input: "shared/embedding/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.72671841736883e-07 - } - } - } -} -node { - name: "shared/embedding/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "shared/embedding/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "shared/embedding/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "shared/embedding/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "shared/embedding/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "shared/embedding/adafactor/sqrt_1/parallel_0/Sqrt" - input: "shared/embedding/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "shared/embedding/adafactor/maximum_1/Const" - input: "shared/embedding/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "shared/embedding/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "shared/embedding/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "shared/embedding/adafactor/einsum/parallel_0/Mul" - input: "shared/embedding/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "shared/embedding/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "shared/embedding/adafactor/einsum_1/parallel_0/Mul" - input: "shared/embedding/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_3" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_4" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_5" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_6" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_7" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.015625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_8" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.015625 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_9" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_9" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_10" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_10" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_11" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_11" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_12" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_12" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_13" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_13" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_14" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_14" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_15" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_15" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_16" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_16" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_17" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_17" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_18" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_18" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_19" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_19" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_20" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_20" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_21" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_21" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "encoder/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "encoder/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_22" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "encoder/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "encoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "encoder/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "encoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "encoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "encoder/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "encoder/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "encoder/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "encoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "encoder/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "encoder/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "encoder/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "encoder/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_22" - input: "encoder/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "encoder/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "encoder/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "encoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "encoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "encoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "encoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "encoder/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "encoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "encoder/rms_norm/scale/adafactor/maximum_1/Const" - input: "encoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "encoder/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "encoder/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "encoder/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "encoder/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "encoder/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "encoder/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_23" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_23" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_24" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_24" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_25" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_25" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_26" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_26" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_27" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_27" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_28" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.015625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_28" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.015625 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_29" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_29" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_30" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_30" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_31" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_31" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_32" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_32" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_33" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_33" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_34" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_34" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_35" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_35" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_36" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_36" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_37" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_37" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_38" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_38" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_39" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_39" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_40" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_40" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_41" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_41" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_42" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_42" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_43" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_43" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_44" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_44" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_45" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_45" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_46" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_46" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_47" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_47" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.000244140625 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_48" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_48" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_49" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_49" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_50" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_50" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_51" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/square/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/square_1/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_add/parallel_0/add" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_51" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/square_2/parallel_0/Square" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00048828125 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/maximum_1/Const" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/einsum/parallel_0/Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/rms_norm/scale_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/rms_norm/scale/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_52" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/rms_norm/scale/adafactor/square/parallel_0/Square" - input: "decoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/rms_norm/scale_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/rms_norm/scale/adafactor/square_1/parallel_0/Square" - input: "decoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/rms_norm/scale/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/rms_norm/scale/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/rms_norm/scale/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/rms_norm/scale/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/rms_norm/scale/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/rms_norm/scale_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/rms_norm/scale/adafactor/scalar_add/parallel_0/add" - input: "decoder/rms_norm/scale/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/rms_norm/scale/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/rms_norm/scale/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_52" - input: "decoder/rms_norm/scale/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/rms_norm/scale/adafactor/square_2/parallel_0/Square" - input: "decoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.03125 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/rms_norm/scale/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/rms_norm/scale/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/rms_norm/scale/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/rms_norm/scale/adafactor/maximum_1/Const" - input: "decoder/rms_norm/scale/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/rms_norm/scale/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/rms_norm/scale/adafactor/einsum/parallel_0/Mul" - input: "decoder/rms_norm/scale/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/rms_norm/scale/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/rms_norm/scale/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel_slot_v_1/group_deps" - op: "NoOp" -} -node { - name: "decoder/logits/kernel_slot_v_1/group_deps_1" - op: "NoOp" -} -node { - name: "decoder/logits/kernel/adafactor/square/parallel_0/Square" - op: "Square" - input: "while_loop/while/Exit_53" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/scalar_add/parallel_0/add/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0000000031710769e-30 - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/scalar_add/parallel_0/add" - op: "AddV2" - input: "decoder/logits/kernel/adafactor/square/parallel_0/Square" - input: "decoder/logits/kernel/adafactor/scalar_add/parallel_0/add/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/square_1/parallel_0/Square" - op: "Square" - input: "decoder/logits/kernel_slice_0/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/logits/kernel/adafactor/square_1/parallel_0/Square" - input: "decoder/logits/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.72671841736883e-07 - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/logits/kernel/adafactor/reduce_mean/reduce/parallel_0/Sum" - input: "decoder/logits/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/sqrt/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/logits/kernel/adafactor/reduce_mean/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/add_1/parallel_0/Maximum" - op: "Maximum" - input: "decoder/logits/kernel/adafactor/sqrt/parallel_0/Sqrt" - input: "decoder/logits/kernel/adafactor/maximum/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/logits/kernel/adafactor/add_1/parallel_0/Maximum" - input: "mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/scalar_mul_1/parallel_0/mul" - op: "Mul" - input: "decoder/logits/kernel_slot_v/read" - input: "sub_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/scalar_mul_2/parallel_0/mul" - op: "Mul" - input: "decoder/logits/kernel/adafactor/scalar_add/parallel_0/add" - input: "decoder/logits/kernel/adafactor/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/add_1_1/parallel_0/Add" - op: "Add" - input: "decoder/logits/kernel/adafactor/scalar_mul_1/parallel_0/mul" - input: "decoder/logits/kernel/adafactor/scalar_mul_2/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - op: "Rsqrt" - input: "decoder/logits/kernel/adafactor/add_1_1/parallel_0/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/einsum/parallel_0/Mul" - op: "Mul" - input: "while_loop/while/Exit_53" - input: "decoder/logits/kernel/adafactor/rsqrt/parallel_0/Rsqrt" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/square_2/parallel_0/Square" - op: "Square" - input: "decoder/logits/kernel/adafactor/einsum/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\000\000\000\000\001\000\000\000" - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - op: "Sum" - input: "decoder/logits/kernel/adafactor/square_2/parallel_0/Square" - input: "decoder/logits/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum/reduction_indices" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 9.72671841736883e-07 - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - op: "Mul" - input: "decoder/logits/kernel/adafactor/reduce_mean_1/reduce/parallel_0/Sum" - input: "decoder/logits/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - op: "Sqrt" - input: "decoder/logits/kernel/adafactor/reduce_mean_1/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/scalar_mul_3/parallel_0/mul" - op: "Mul" - input: "decoder/logits/kernel/adafactor/sqrt_1/parallel_0/Sqrt" - input: "decoder/logits/kernel/adafactor/scalar_mul_3/parallel_0/mul/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/add_2/parallel_0/Maximum" - op: "Maximum" - input: "decoder/logits/kernel/adafactor/maximum_1/Const" - input: "decoder/logits/kernel/adafactor/scalar_mul_3/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - op: "Reciprocal" - input: "decoder/logits/kernel/adafactor/add_2/parallel_0/Maximum" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/einsum_1/parallel_0/Mul" - op: "Mul" - input: "decoder/logits/kernel/adafactor/einsum/parallel_0/Mul" - input: "decoder/logits/kernel/adafactor/reciprocal/parallel_0/Reciprocal" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "decoder/logits/kernel/adafactor/einsum_2/parallel_0/Mul" - op: "Mul" - input: "decoder/logits/kernel/adafactor/einsum_1/parallel_0/Mul" - input: "decoder/logits/kernel/adafactor/scalar_mul/parallel_0/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "assign/parallel_0/Assign" - op: "Assign" - input: "shared/embedding_slot_v" - input: "shared/embedding/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps" - op: "NoOp" - input: "^assign/parallel_0/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_1" - op: "NoOp" - input: "^assign/parallel_0_1/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_2/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_2" - op: "NoOp" - input: "^assign/parallel_0_2/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_3/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_3" - op: "NoOp" - input: "^assign/parallel_0_3/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_4/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_4" - op: "NoOp" - input: "^assign/parallel_0_4/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_5/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_5" - op: "NoOp" - input: "^assign/parallel_0_5/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_6/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_6" - op: "NoOp" - input: "^assign/parallel_0_6/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_7/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_7" - op: "NoOp" - input: "^assign/parallel_0_7/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_8/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_8" - op: "NoOp" - input: "^assign/parallel_0_8/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_9/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_9" - op: "NoOp" - input: "^assign/parallel_0_9/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_10/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_10" - op: "NoOp" - input: "^assign/parallel_0_10/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_11/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_11" - op: "NoOp" - input: "^assign/parallel_0_11/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_12/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_12" - op: "NoOp" - input: "^assign/parallel_0_12/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_13/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_13" - op: "NoOp" - input: "^assign/parallel_0_13/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_14/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_14" - op: "NoOp" - input: "^assign/parallel_0_14/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_15/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_15" - op: "NoOp" - input: "^assign/parallel_0_15/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_16/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_16" - op: "NoOp" - input: "^assign/parallel_0_16/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_17/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_17" - op: "NoOp" - input: "^assign/parallel_0_17/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_18/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_18" - op: "NoOp" - input: "^assign/parallel_0_18/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_19/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_19" - op: "NoOp" - input: "^assign/parallel_0_19/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_20/Assign" - op: "Assign" - input: "encoder/rms_norm/scale_slot_v" - input: "encoder/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_20" - op: "NoOp" - input: "^assign/parallel_0_20/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_21/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_21" - op: "NoOp" - input: "^assign/parallel_0_21/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_22/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_22" - op: "NoOp" - input: "^assign/parallel_0_22/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_23/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_23" - op: "NoOp" - input: "^assign/parallel_0_23/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_24/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_24" - op: "NoOp" - input: "^assign/parallel_0_24/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_25/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_25" - op: "NoOp" - input: "^assign/parallel_0_25/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_26/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_26" - op: "NoOp" - input: "^assign/parallel_0_26/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_27/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_27" - op: "NoOp" - input: "^assign/parallel_0_27/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_28/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_28" - op: "NoOp" - input: "^assign/parallel_0_28/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_29/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_29" - op: "NoOp" - input: "^assign/parallel_0_29/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_30/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_30" - op: "NoOp" - input: "^assign/parallel_0_30/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_31/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_31" - op: "NoOp" - input: "^assign/parallel_0_31/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_32/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_32" - op: "NoOp" - input: "^assign/parallel_0_32/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_33/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_33" - op: "NoOp" - input: "^assign/parallel_0_33/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_34/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_34" - op: "NoOp" - input: "^assign/parallel_0_34/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_35/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_35" - op: "NoOp" - input: "^assign/parallel_0_35/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_36/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_36" - op: "NoOp" - input: "^assign/parallel_0_36/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_37/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_37" - op: "NoOp" - input: "^assign/parallel_0_37/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_38/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_38" - op: "NoOp" - input: "^assign/parallel_0_38/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_39/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_39" - op: "NoOp" - input: "^assign/parallel_0_39/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_40/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_40" - op: "NoOp" - input: "^assign/parallel_0_40/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_41/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_41" - op: "NoOp" - input: "^assign/parallel_0_41/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_42/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_42" - op: "NoOp" - input: "^assign/parallel_0_42/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_43/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_43" - op: "NoOp" - input: "^assign/parallel_0_43/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_44/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_44" - op: "NoOp" - input: "^assign/parallel_0_44/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_45/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_45" - op: "NoOp" - input: "^assign/parallel_0_45/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_46/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_46" - op: "NoOp" - input: "^assign/parallel_0_46/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_47/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_47" - op: "NoOp" - input: "^assign/parallel_0_47/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_48/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_48" - op: "NoOp" - input: "^assign/parallel_0_48/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_49/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_49" - op: "NoOp" - input: "^assign/parallel_0_49/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_50/Assign" - op: "Assign" - input: "decoder/rms_norm/scale_slot_v" - input: "decoder/rms_norm/scale/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_50" - op: "NoOp" - input: "^assign/parallel_0_50/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/parallel_0_51/Assign" - op: "Assign" - input: "decoder/logits/kernel_slot_v" - input: "decoder/logits/kernel/adafactor/add_1_1/parallel_0/Add" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign/group_deps_51" - op: "NoOp" - input: "^assign/parallel_0_51/Assign" - device: "/device:CPU:0" -} -node { - name: "assign/group_deps_52" - op: "NoOp" - input: "^assign/group_deps" - input: "^assign/group_deps_1" - input: "^assign/group_deps_10" - input: "^assign/group_deps_11" - input: "^assign/group_deps_12" - input: "^assign/group_deps_13" - input: "^assign/group_deps_14" - input: "^assign/group_deps_15" - input: "^assign/group_deps_16" - input: "^assign/group_deps_17" - input: "^assign/group_deps_18" - input: "^assign/group_deps_19" - input: "^assign/group_deps_2" - input: "^assign/group_deps_20" - input: "^assign/group_deps_21" - input: "^assign/group_deps_22" - input: "^assign/group_deps_23" - input: "^assign/group_deps_24" - input: "^assign/group_deps_25" - input: "^assign/group_deps_26" - input: "^assign/group_deps_27" - input: "^assign/group_deps_28" - input: "^assign/group_deps_29" - input: "^assign/group_deps_3" - input: "^assign/group_deps_30" - input: "^assign/group_deps_31" - input: "^assign/group_deps_32" - input: "^assign/group_deps_33" - input: "^assign/group_deps_34" - input: "^assign/group_deps_35" - input: "^assign/group_deps_36" - input: "^assign/group_deps_37" - input: "^assign/group_deps_38" - input: "^assign/group_deps_39" - input: "^assign/group_deps_4" - input: "^assign/group_deps_40" - input: "^assign/group_deps_41" - input: "^assign/group_deps_42" - input: "^assign/group_deps_43" - input: "^assign/group_deps_44" - input: "^assign/group_deps_45" - input: "^assign/group_deps_46" - input: "^assign/group_deps_47" - input: "^assign/group_deps_48" - input: "^assign/group_deps_49" - input: "^assign/group_deps_5" - input: "^assign/group_deps_50" - input: "^assign/group_deps_51" - input: "^assign/group_deps_6" - input: "^assign/group_deps_7" - input: "^assign/group_deps_8" - input: "^assign/group_deps_9" - device: "/device:CPU:0" -} -node { - name: "assign_1/parallel_0/sub" - op: "Sub" - input: "shared/embedding_slice_0/read" - input: "shared/embedding/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0/Assign" - op: "Assign" - input: "shared/embedding_slice_0" - input: "assign_1/parallel_0/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps" - op: "NoOp" - input: "^assign_1/parallel_0/Assign" -} -node { - name: "assign_1/parallel_0_1/sub" - op: "Sub" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0/read" - input: "encoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_1/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_1/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_1" - op: "NoOp" - input: "^assign_1/parallel_0_1/Assign" -} -node { - name: "assign_1/parallel_0_2/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0/read" - input: "encoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_2/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q_slice_0" - input: "assign_1/parallel_0_2/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_2" - op: "NoOp" - input: "^assign_1/parallel_0_2/Assign" -} -node { - name: "assign_1/parallel_0_3/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0/read" - input: "encoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_3/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k_slice_0" - input: "assign_1/parallel_0_3/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_3" - op: "NoOp" - input: "^assign_1/parallel_0_3/Assign" -} -node { - name: "assign_1/parallel_0_4/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0/read" - input: "encoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_4/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v_slice_0" - input: "assign_1/parallel_0_4/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_4" - op: "NoOp" - input: "^assign_1/parallel_0_4/Assign" -} -node { - name: "assign_1/parallel_0_5/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0/read" - input: "encoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_5/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o_slice_0" - input: "assign_1/parallel_0_5/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_5" - op: "NoOp" - input: "^assign_1/parallel_0_5/Assign" -} -node { - name: "assign_1/parallel_0_6/sub" - op: "Sub" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_6/Assign" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - input: "assign_1/parallel_0_6/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_6" - op: "NoOp" - input: "^assign_1/parallel_0_6/Assign" -} -node { - name: "assign_1/parallel_0_7/sub" - op: "Sub" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0/read" - input: "encoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_7/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_7/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_7" - op: "NoOp" - input: "^assign_1/parallel_0_7/Assign" -} -node { - name: "assign_1/parallel_0_8/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_8/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - input: "assign_1/parallel_0_8/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_8" - op: "NoOp" - input: "^assign_1/parallel_0_8/Assign" -} -node { - name: "assign_1/parallel_0_9/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_9/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - input: "assign_1/parallel_0_9/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_9" - op: "NoOp" - input: "^assign_1/parallel_0_9/Assign" -} -node { - name: "assign_1/parallel_0_10/sub" - op: "Sub" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0/read" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_10/Assign" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - input: "assign_1/parallel_0_10/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_10" - op: "NoOp" - input: "^assign_1/parallel_0_10/Assign" -} -node { - name: "assign_1/parallel_0_11/sub" - op: "Sub" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0/read" - input: "encoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_11/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_11/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_11" - op: "NoOp" - input: "^assign_1/parallel_0_11/Assign" -} -node { - name: "assign_1/parallel_0_12/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0/read" - input: "encoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_12/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q_slice_0" - input: "assign_1/parallel_0_12/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_12" - op: "NoOp" - input: "^assign_1/parallel_0_12/Assign" -} -node { - name: "assign_1/parallel_0_13/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0/read" - input: "encoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_13/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k_slice_0" - input: "assign_1/parallel_0_13/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_13" - op: "NoOp" - input: "^assign_1/parallel_0_13/Assign" -} -node { - name: "assign_1/parallel_0_14/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0/read" - input: "encoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_14/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v_slice_0" - input: "assign_1/parallel_0_14/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_14" - op: "NoOp" - input: "^assign_1/parallel_0_14/Assign" -} -node { - name: "assign_1/parallel_0_15/sub" - op: "Sub" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0/read" - input: "encoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_15/Assign" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o_slice_0" - input: "assign_1/parallel_0_15/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_15" - op: "NoOp" - input: "^assign_1/parallel_0_15/Assign" -} -node { - name: "assign_1/parallel_0_16/sub" - op: "Sub" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0/read" - input: "encoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_16/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_16/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_16" - op: "NoOp" - input: "^assign_1/parallel_0_16/Assign" -} -node { - name: "assign_1/parallel_0_17/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0/read" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_17/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - input: "assign_1/parallel_0_17/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_17" - op: "NoOp" - input: "^assign_1/parallel_0_17/Assign" -} -node { - name: "assign_1/parallel_0_18/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0/read" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_18/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - input: "assign_1/parallel_0_18/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_18" - op: "NoOp" - input: "^assign_1/parallel_0_18/Assign" -} -node { - name: "assign_1/parallel_0_19/sub" - op: "Sub" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0/read" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_19/Assign" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - input: "assign_1/parallel_0_19/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_19" - op: "NoOp" - input: "^assign_1/parallel_0_19/Assign" -} -node { - name: "assign_1/parallel_0_20/sub" - op: "Sub" - input: "encoder/rms_norm/scale_slice_0/read" - input: "encoder/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_20/Assign" - op: "Assign" - input: "encoder/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_20/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_20" - op: "NoOp" - input: "^assign_1/parallel_0_20/Assign" -} -node { - name: "assign_1/parallel_0_21/sub" - op: "Sub" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0/read" - input: "decoder/block_000/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_21/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_21/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_21" - op: "NoOp" - input: "^assign_1/parallel_0_21/Assign" -} -node { - name: "assign_1/parallel_0_22/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0/read" - input: "decoder/block_000/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_22/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q_slice_0" - input: "assign_1/parallel_0_22/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_22" - op: "NoOp" - input: "^assign_1/parallel_0_22/Assign" -} -node { - name: "assign_1/parallel_0_23/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0/read" - input: "decoder/block_000/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_23/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k_slice_0" - input: "assign_1/parallel_0_23/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_23" - op: "NoOp" - input: "^assign_1/parallel_0_23/Assign" -} -node { - name: "assign_1/parallel_0_24/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0/read" - input: "decoder/block_000/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_24/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v_slice_0" - input: "assign_1/parallel_0_24/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_24" - op: "NoOp" - input: "^assign_1/parallel_0_24/Assign" -} -node { - name: "assign_1/parallel_0_25/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0/read" - input: "decoder/block_000/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_25/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o_slice_0" - input: "assign_1/parallel_0_25/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_25" - op: "NoOp" - input: "^assign_1/parallel_0_25/Assign" -} -node { - name: "assign_1/parallel_0_26/sub" - op: "Sub" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0/read" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_26/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - input: "assign_1/parallel_0_26/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_26" - op: "NoOp" - input: "^assign_1/parallel_0_26/Assign" -} -node { - name: "assign_1/parallel_0_27/sub" - op: "Sub" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0/read" - input: "decoder/block_000/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_27/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_27/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_27" - op: "NoOp" - input: "^assign_1/parallel_0_27/Assign" -} -node { - name: "assign_1/parallel_0_28/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0/read" - input: "decoder/block_000/layer_001/EncDecAttention/q/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_28/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q_slice_0" - input: "assign_1/parallel_0_28/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_28" - op: "NoOp" - input: "^assign_1/parallel_0_28/Assign" -} -node { - name: "assign_1/parallel_0_29/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0/read" - input: "decoder/block_000/layer_001/EncDecAttention/k/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_29/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k_slice_0" - input: "assign_1/parallel_0_29/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_29" - op: "NoOp" - input: "^assign_1/parallel_0_29/Assign" -} -node { - name: "assign_1/parallel_0_30/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0/read" - input: "decoder/block_000/layer_001/EncDecAttention/v/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_30/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v_slice_0" - input: "assign_1/parallel_0_30/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_30" - op: "NoOp" - input: "^assign_1/parallel_0_30/Assign" -} -node { - name: "assign_1/parallel_0_31/sub" - op: "Sub" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0/read" - input: "decoder/block_000/layer_001/EncDecAttention/o/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_31/Assign" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o_slice_0" - input: "assign_1/parallel_0_31/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_31" - op: "NoOp" - input: "^assign_1/parallel_0_31/Assign" -} -node { - name: "assign_1/parallel_0_32/sub" - op: "Sub" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0/read" - input: "decoder/block_000/layer_002/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_32/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_32/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_32" - op: "NoOp" - input: "^assign_1/parallel_0_32/Assign" -} -node { - name: "assign_1/parallel_0_33/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_33/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - input: "assign_1/parallel_0_33/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_33" - op: "NoOp" - input: "^assign_1/parallel_0_33/Assign" -} -node { - name: "assign_1/parallel_0_34/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_34/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - input: "assign_1/parallel_0_34/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_34" - op: "NoOp" - input: "^assign_1/parallel_0_34/Assign" -} -node { - name: "assign_1/parallel_0_35/sub" - op: "Sub" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0/read" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_35/Assign" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - input: "assign_1/parallel_0_35/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_35" - op: "NoOp" - input: "^assign_1/parallel_0_35/Assign" -} -node { - name: "assign_1/parallel_0_36/sub" - op: "Sub" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0/read" - input: "decoder/block_001/layer_000/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_36/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_36/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_36" - op: "NoOp" - input: "^assign_1/parallel_0_36/Assign" -} -node { - name: "assign_1/parallel_0_37/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0/read" - input: "decoder/block_001/layer_000/SelfAttention/q/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_37/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q_slice_0" - input: "assign_1/parallel_0_37/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_37" - op: "NoOp" - input: "^assign_1/parallel_0_37/Assign" -} -node { - name: "assign_1/parallel_0_38/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0/read" - input: "decoder/block_001/layer_000/SelfAttention/k/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_38/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k_slice_0" - input: "assign_1/parallel_0_38/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_38" - op: "NoOp" - input: "^assign_1/parallel_0_38/Assign" -} -node { - name: "assign_1/parallel_0_39/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0/read" - input: "decoder/block_001/layer_000/SelfAttention/v/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_39/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v_slice_0" - input: "assign_1/parallel_0_39/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_39" - op: "NoOp" - input: "^assign_1/parallel_0_39/Assign" -} -node { - name: "assign_1/parallel_0_40/sub" - op: "Sub" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0/read" - input: "decoder/block_001/layer_000/SelfAttention/o/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_40/Assign" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o_slice_0" - input: "assign_1/parallel_0_40/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_40" - op: "NoOp" - input: "^assign_1/parallel_0_40/Assign" -} -node { - name: "assign_1/parallel_0_41/sub" - op: "Sub" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0/read" - input: "decoder/block_001/layer_001/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_41/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_41/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_41" - op: "NoOp" - input: "^assign_1/parallel_0_41/Assign" -} -node { - name: "assign_1/parallel_0_42/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0/read" - input: "decoder/block_001/layer_001/EncDecAttention/q/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_42/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q_slice_0" - input: "assign_1/parallel_0_42/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_42" - op: "NoOp" - input: "^assign_1/parallel_0_42/Assign" -} -node { - name: "assign_1/parallel_0_43/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0/read" - input: "decoder/block_001/layer_001/EncDecAttention/k/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_43/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k_slice_0" - input: "assign_1/parallel_0_43/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_43" - op: "NoOp" - input: "^assign_1/parallel_0_43/Assign" -} -node { - name: "assign_1/parallel_0_44/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0/read" - input: "decoder/block_001/layer_001/EncDecAttention/v/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_44/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v_slice_0" - input: "assign_1/parallel_0_44/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_44" - op: "NoOp" - input: "^assign_1/parallel_0_44/Assign" -} -node { - name: "assign_1/parallel_0_45/sub" - op: "Sub" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0/read" - input: "decoder/block_001/layer_001/EncDecAttention/o/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_45/Assign" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o_slice_0" - input: "assign_1/parallel_0_45/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_45" - op: "NoOp" - input: "^assign_1/parallel_0_45/Assign" -} -node { - name: "assign_1/parallel_0_46/sub" - op: "Sub" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0/read" - input: "decoder/block_001/layer_002/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_46/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_46/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_46" - op: "NoOp" - input: "^assign_1/parallel_0_46/Assign" -} -node { - name: "assign_1/parallel_0_47/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0/read" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_47/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - input: "assign_1/parallel_0_47/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_47" - op: "NoOp" - input: "^assign_1/parallel_0_47/Assign" -} -node { - name: "assign_1/parallel_0_48/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0/read" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_48/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - input: "assign_1/parallel_0_48/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_48" - op: "NoOp" - input: "^assign_1/parallel_0_48/Assign" -} -node { - name: "assign_1/parallel_0_49/sub" - op: "Sub" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0/read" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_49/Assign" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - input: "assign_1/parallel_0_49/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_49" - op: "NoOp" - input: "^assign_1/parallel_0_49/Assign" -} -node { - name: "assign_1/parallel_0_50/sub" - op: "Sub" - input: "decoder/rms_norm/scale_slice_0/read" - input: "decoder/rms_norm/scale/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_50/Assign" - op: "Assign" - input: "decoder/rms_norm/scale_slice_0" - input: "assign_1/parallel_0_50/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_50" - op: "NoOp" - input: "^assign_1/parallel_0_50/Assign" -} -node { - name: "assign_1/parallel_0_51/sub" - op: "Sub" - input: "decoder/logits/kernel_slice_0/read" - input: "decoder/logits/kernel/adafactor/einsum_2/parallel_0/Mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } -} -node { - name: "assign_1/parallel_0_51/Assign" - op: "Assign" - input: "decoder/logits/kernel_slice_0" - input: "assign_1/parallel_0_51/sub" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slice_0" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "assign_1/group_deps_51" - op: "NoOp" - input: "^assign_1/parallel_0_51/Assign" -} -node { - name: "assign_1/group_deps_52" - op: "NoOp" - input: "^assign_1/group_deps" - input: "^assign_1/group_deps_1" - input: "^assign_1/group_deps_10" - input: "^assign_1/group_deps_11" - input: "^assign_1/group_deps_12" - input: "^assign_1/group_deps_13" - input: "^assign_1/group_deps_14" - input: "^assign_1/group_deps_15" - input: "^assign_1/group_deps_16" - input: "^assign_1/group_deps_17" - input: "^assign_1/group_deps_18" - input: "^assign_1/group_deps_19" - input: "^assign_1/group_deps_2" - input: "^assign_1/group_deps_20" - input: "^assign_1/group_deps_21" - input: "^assign_1/group_deps_22" - input: "^assign_1/group_deps_23" - input: "^assign_1/group_deps_24" - input: "^assign_1/group_deps_25" - input: "^assign_1/group_deps_26" - input: "^assign_1/group_deps_27" - input: "^assign_1/group_deps_28" - input: "^assign_1/group_deps_29" - input: "^assign_1/group_deps_3" - input: "^assign_1/group_deps_30" - input: "^assign_1/group_deps_31" - input: "^assign_1/group_deps_32" - input: "^assign_1/group_deps_33" - input: "^assign_1/group_deps_34" - input: "^assign_1/group_deps_35" - input: "^assign_1/group_deps_36" - input: "^assign_1/group_deps_37" - input: "^assign_1/group_deps_38" - input: "^assign_1/group_deps_39" - input: "^assign_1/group_deps_4" - input: "^assign_1/group_deps_40" - input: "^assign_1/group_deps_41" - input: "^assign_1/group_deps_42" - input: "^assign_1/group_deps_43" - input: "^assign_1/group_deps_44" - input: "^assign_1/group_deps_45" - input: "^assign_1/group_deps_46" - input: "^assign_1/group_deps_47" - input: "^assign_1/group_deps_48" - input: "^assign_1/group_deps_49" - input: "^assign_1/group_deps_5" - input: "^assign_1/group_deps_50" - input: "^assign_1/group_deps_51" - input: "^assign_1/group_deps_6" - input: "^assign_1/group_deps_7" - input: "^assign_1/group_deps_8" - input: "^assign_1/group_deps_9" -} -node { - name: "Print/ReadVariableOp" - op: "ReadVariableOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "Print" - op: "Print" - input: "while_loop/while/Exit_1" - input: "while_loop/while/Exit_1" - input: "Print/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "U" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "first_n" - value { - i: -1 - } - } - attr { - key: "message" - value { - s: "step, tf_loss" - } - } - attr { - key: "summarize" - value { - i: 3 - } - } -} -node { - name: "Const_3" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1 - } - } - } -} -node { - name: "AssignAddVariableOp" - op: "AssignAddVariableOp" - input: "global_step" - input: "Const_3" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "ReadVariableOp_4" - op: "ReadVariableOp" - input: "global_step" - input: "^AssignAddVariableOp" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "group_deps/NoOp" - op: "NoOp" - input: "^AssignAddVariableOp" - input: "^assign_1/group_deps_52" -} -node { - name: "group_deps/NoOp_1" - op: "NoOp" - input: "^assign/group_deps_52" - device: "/device:CPU:0" -} -node { - name: "group_deps" - op: "NoOp" - input: "^group_deps/NoOp" - input: "^group_deps/NoOp_1" -} -node { - name: "save/filename/input" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "model" - } - } - } -} -node { - name: "save/filename" - op: "PlaceholderWithDefault" - input: "save/filename/input" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "shape" - value { - shape { - } - } - } -} -node { - name: "save/Const" - op: "PlaceholderWithDefault" - input: "save/filename" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "shape" - value { - shape { - } - } - } -} -node { - name: "save/StaticRegexFullMatch" - op: "StaticRegexFullMatch" - input: "save/Const" - device: "/device:CPU:*" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "pattern" - value { - s: "^s3://.*" - } - } -} -node { - name: "save/Const_1" - op: "Const" - device: "/device:CPU:*" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: ".part" - } - } - } -} -node { - name: "save/Const_2" - op: "Const" - device: "/device:CPU:*" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "_temp/part" - } - } - } -} -node { - name: "save/Select" - op: "Select" - input: "save/StaticRegexFullMatch" - input: "save/Const_1" - input: "save/Const_2" - device: "/device:CPU:*" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "save/StringJoin" - op: "StringJoin" - input: "save/Const" - input: "save/Select" - device: "/device:CPU:*" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "separator" - value { - s: "" - } - } -} -node { - name: "save/num_shards" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "save/ShardedFilename/shard" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "save/ShardedFilename" - op: "ShardedFilename" - input: "save/StringJoin" - input: "save/ShardedFilename/shard" - input: "save/num_shards" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "save/SaveV2/tensor_names" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "global_step" - } - } - } -} -node { - name: "save/SaveV2/shape_and_slices" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } -} -node { - name: "save/SaveV2" - op: "SaveV2" - input: "save/ShardedFilename" - input: "save/SaveV2/tensor_names" - input: "save/SaveV2/shape_and_slices" - input: "global_step/Read/ReadVariableOp" - device: "/device:CPU:0" - attr { - key: "dtypes" - value { - list { - type: DT_INT64 - } - } - } -} -node { - name: "save/control_dependency" - op: "Identity" - input: "save/ShardedFilename" - input: "^save/SaveV2" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@save/ShardedFilename" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "save/ShardedFilename_1/shard" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "save/ShardedFilename_1" - op: "ShardedFilename" - input: "save/StringJoin" - input: "save/ShardedFilename_1/shard" - input: "save/num_shards" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "save/SaveV2_1/tensor_names" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 104 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 104 - } - } - string_val: "decoder/block_000/layer_000/SelfAttention/k" - string_val: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/o" - string_val: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/q" - string_val: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/v" - string_val: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_000/layer_000/rms_norm/scale" - string_val: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/k" - string_val: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/o" - string_val: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/q" - string_val: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/v" - string_val: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_000/layer_001/rms_norm/scale" - string_val: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/block_000/layer_002/rms_norm/scale" - string_val: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/k" - string_val: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/o" - string_val: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/q" - string_val: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/v" - string_val: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_001/layer_000/rms_norm/scale" - string_val: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/k" - string_val: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/o" - string_val: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/q" - string_val: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/v" - string_val: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_001/layer_001/rms_norm/scale" - string_val: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/block_001/layer_002/rms_norm/scale" - string_val: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/logits/kernel" - string_val: "decoder/logits/kernel_slot_v" - string_val: "decoder/rms_norm/scale" - string_val: "decoder/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/k" - string_val: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/o" - string_val: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/q" - string_val: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/v" - string_val: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_000/layer_000/rms_norm/scale" - string_val: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/block_000/layer_001/rms_norm/scale" - string_val: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/k" - string_val: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/o" - string_val: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/q" - string_val: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/v" - string_val: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_001/layer_000/rms_norm/scale" - string_val: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/block_001/layer_001/rms_norm/scale" - string_val: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/rms_norm/scale" - string_val: "encoder/rms_norm/scale_slot_v" - string_val: "shared/embedding" - string_val: "shared/embedding_slot_v" - } - } - } -} -node { - name: "save/SaveV2_1/shape_and_slices" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 104 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 104 - } - } - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - } - } - } -} -node { - name: "save/SaveV2_1" - op: "SaveV2" - input: "save/ShardedFilename_1" - input: "save/SaveV2_1/tensor_names" - input: "save/SaveV2_1/shape_and_slices" - input: "decoder/block_000/layer_000/SelfAttention/k" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/o" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/q" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "decoder/block_000/layer_000/SelfAttention/v" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "decoder/block_000/layer_000/rms_norm/scale" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/k" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/o" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/q" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - input: "decoder/block_000/layer_001/EncDecAttention/v" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - input: "decoder/block_000/layer_001/rms_norm/scale" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "decoder/block_000/layer_002/rms_norm/scale" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/k" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/o" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/q" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "decoder/block_001/layer_000/SelfAttention/v" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "decoder/block_001/layer_000/rms_norm/scale" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/k" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/o" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/q" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - input: "decoder/block_001/layer_001/EncDecAttention/v" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - input: "decoder/block_001/layer_001/rms_norm/scale" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "decoder/block_001/layer_002/rms_norm/scale" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - input: "decoder/logits/kernel" - input: "decoder/logits/kernel_slot_v" - input: "decoder/rms_norm/scale" - input: "decoder/rms_norm/scale_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/k" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/o" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/q" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "encoder/block_000/layer_000/SelfAttention/v" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "encoder/block_000/layer_000/rms_norm/scale" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "encoder/block_000/layer_001/rms_norm/scale" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/k" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/o" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/q" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "encoder/block_001/layer_000/SelfAttention/v" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "encoder/block_001/layer_000/rms_norm/scale" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "encoder/block_001/layer_001/rms_norm/scale" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "encoder/rms_norm/scale" - input: "encoder/rms_norm/scale_slot_v" - input: "shared/embedding" - input: "shared/embedding_slot_v" - device: "/device:CPU:0" - attr { - key: "dtypes" - value { - list { - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - } - } - } -} -node { - name: "save/control_dependency_1" - op: "Identity" - input: "save/ShardedFilename_1" - input: "^save/SaveV2_1" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@save/ShardedFilename_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "save/MergeV2Checkpoints/checkpoint_prefixes" - op: "Pack" - input: "save/ShardedFilename" - input: "save/ShardedFilename_1" - input: "^save/control_dependency" - input: "^save/control_dependency_1" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "save/MergeV2Checkpoints" - op: "MergeV2Checkpoints" - input: "save/MergeV2Checkpoints/checkpoint_prefixes" - input: "save/Const" - device: "/device:CPU:0" - attr { - key: "delete_old_dirs" - value { - b: true - } - } -} -node { - name: "save/Identity" - op: "Identity" - input: "save/Const" - input: "^save/MergeV2Checkpoints" - input: "^save/control_dependency" - input: "^save/control_dependency_1" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "save/RestoreV2/tensor_names" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "global_step" - } - } - } -} -node { - name: "save/RestoreV2/shape_and_slices" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } -} -node { - name: "save/RestoreV2" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2/tensor_names" - input: "save/RestoreV2/shape_and_slices" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_INT64 - } - } - } -} -node { - name: "save/Identity_1" - op: "Identity" - input: "save/RestoreV2" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } -} -node { - name: "save/AssignVariableOp" - op: "AssignVariableOp" - input: "global_step" - input: "save/Identity_1" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } -} -node { - name: "save/restore_shard" - op: "NoOp" - input: "^save/AssignVariableOp" -} -node { - name: "save/RestoreV2_1/tensor_names" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 104 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 104 - } - } - string_val: "decoder/block_000/layer_000/SelfAttention/k" - string_val: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/o" - string_val: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/q" - string_val: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/v" - string_val: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_000/layer_000/rms_norm/scale" - string_val: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/k" - string_val: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/o" - string_val: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/q" - string_val: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/v" - string_val: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_000/layer_001/rms_norm/scale" - string_val: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/block_000/layer_002/rms_norm/scale" - string_val: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/k" - string_val: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/o" - string_val: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/q" - string_val: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/v" - string_val: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_001/layer_000/rms_norm/scale" - string_val: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/k" - string_val: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/o" - string_val: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/q" - string_val: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/v" - string_val: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_001/layer_001/rms_norm/scale" - string_val: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/block_001/layer_002/rms_norm/scale" - string_val: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/logits/kernel" - string_val: "decoder/logits/kernel_slot_v" - string_val: "decoder/rms_norm/scale" - string_val: "decoder/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/k" - string_val: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/o" - string_val: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/q" - string_val: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/v" - string_val: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_000/layer_000/rms_norm/scale" - string_val: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/block_000/layer_001/rms_norm/scale" - string_val: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/k" - string_val: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/o" - string_val: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/q" - string_val: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/v" - string_val: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_001/layer_000/rms_norm/scale" - string_val: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/block_001/layer_001/rms_norm/scale" - string_val: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/rms_norm/scale" - string_val: "encoder/rms_norm/scale_slot_v" - string_val: "shared/embedding" - string_val: "shared/embedding_slot_v" - } - } - } -} -node { - name: "save/RestoreV2_1/shape_and_slices" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 104 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 104 - } - } - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - string_val: "" - } - } - } -} -node { - name: "save/RestoreV2_1" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2_1/tensor_names" - input: "save/RestoreV2_1/shape_and_slices" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - type: DT_BFLOAT16 - type: DT_FLOAT - } - } - } -} -node { - name: "save/Assign" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k" - input: "save/RestoreV2_1" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_1" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "save/RestoreV2_1:1" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_2" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o" - input: "save/RestoreV2_1:2" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_3" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "save/RestoreV2_1:3" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_4" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q" - input: "save/RestoreV2_1:4" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_5" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "save/RestoreV2_1:5" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_6" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "save/RestoreV2_1:6" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_7" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "save/RestoreV2_1:7" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_8" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v" - input: "save/RestoreV2_1:8" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_9" - op: "Assign" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "save/RestoreV2_1:9" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_10" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale" - input: "save/RestoreV2_1:10" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_11" - op: "Assign" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:11" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_12" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k" - input: "save/RestoreV2_1:12" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_13" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - input: "save/RestoreV2_1:13" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_14" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o" - input: "save/RestoreV2_1:14" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_15" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - input: "save/RestoreV2_1:15" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_16" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q" - input: "save/RestoreV2_1:16" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_17" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - input: "save/RestoreV2_1:17" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_18" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v" - input: "save/RestoreV2_1:18" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_19" - op: "Assign" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - input: "save/RestoreV2_1:19" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_20" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale" - input: "save/RestoreV2_1:20" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_21" - op: "Assign" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:21" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_22" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - input: "save/RestoreV2_1:22" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_23" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "save/RestoreV2_1:23" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_24" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - input: "save/RestoreV2_1:24" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_25" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "save/RestoreV2_1:25" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_26" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - input: "save/RestoreV2_1:26" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_27" - op: "Assign" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "save/RestoreV2_1:27" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_28" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale" - input: "save/RestoreV2_1:28" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_29" - op: "Assign" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:29" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_30" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k" - input: "save/RestoreV2_1:30" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_31" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "save/RestoreV2_1:31" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_32" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o" - input: "save/RestoreV2_1:32" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_33" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "save/RestoreV2_1:33" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_34" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q" - input: "save/RestoreV2_1:34" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_35" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "save/RestoreV2_1:35" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_36" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v" - input: "save/RestoreV2_1:36" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_37" - op: "Assign" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "save/RestoreV2_1:37" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_38" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale" - input: "save/RestoreV2_1:38" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_39" - op: "Assign" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:39" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_40" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k" - input: "save/RestoreV2_1:40" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_41" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - input: "save/RestoreV2_1:41" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_42" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o" - input: "save/RestoreV2_1:42" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_43" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - input: "save/RestoreV2_1:43" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_44" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q" - input: "save/RestoreV2_1:44" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_45" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - input: "save/RestoreV2_1:45" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_46" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v" - input: "save/RestoreV2_1:46" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_47" - op: "Assign" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - input: "save/RestoreV2_1:47" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_48" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale" - input: "save/RestoreV2_1:48" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_49" - op: "Assign" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:49" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_50" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - input: "save/RestoreV2_1:50" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_51" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - input: "save/RestoreV2_1:51" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_52" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - input: "save/RestoreV2_1:52" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_53" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - input: "save/RestoreV2_1:53" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_54" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - input: "save/RestoreV2_1:54" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_55" - op: "Assign" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - input: "save/RestoreV2_1:55" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_56" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale" - input: "save/RestoreV2_1:56" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_57" - op: "Assign" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:57" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_58" - op: "Assign" - input: "decoder/logits/kernel" - input: "save/RestoreV2_1:58" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_59" - op: "Assign" - input: "decoder/logits/kernel_slot_v" - input: "save/RestoreV2_1:59" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 32128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_60" - op: "Assign" - input: "decoder/rms_norm/scale" - input: "save/RestoreV2_1:60" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_61" - op: "Assign" - input: "decoder/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:61" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_62" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k" - input: "save/RestoreV2_1:62" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_63" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - input: "save/RestoreV2_1:63" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_64" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o" - input: "save/RestoreV2_1:64" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_65" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - input: "save/RestoreV2_1:65" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_66" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q" - input: "save/RestoreV2_1:66" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_67" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - input: "save/RestoreV2_1:67" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_68" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - input: "save/RestoreV2_1:68" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_69" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - input: "save/RestoreV2_1:69" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_70" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v" - input: "save/RestoreV2_1:70" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_71" - op: "Assign" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - input: "save/RestoreV2_1:71" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_72" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale" - input: "save/RestoreV2_1:72" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_73" - op: "Assign" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:73" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_74" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - input: "save/RestoreV2_1:74" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_75" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "save/RestoreV2_1:75" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_76" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - input: "save/RestoreV2_1:76" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_77" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "save/RestoreV2_1:77" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_78" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - input: "save/RestoreV2_1:78" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_79" - op: "Assign" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "save/RestoreV2_1:79" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_80" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale" - input: "save/RestoreV2_1:80" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_81" - op: "Assign" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:81" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_82" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k" - input: "save/RestoreV2_1:82" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_83" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - input: "save/RestoreV2_1:83" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_84" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o" - input: "save/RestoreV2_1:84" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_85" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - input: "save/RestoreV2_1:85" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_86" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q" - input: "save/RestoreV2_1:86" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_87" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - input: "save/RestoreV2_1:87" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_88" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v" - input: "save/RestoreV2_1:88" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_89" - op: "Assign" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - input: "save/RestoreV2_1:89" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 128 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_90" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale" - input: "save/RestoreV2_1:90" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_91" - op: "Assign" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:91" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_92" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - input: "save/RestoreV2_1:92" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_93" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - input: "save/RestoreV2_1:93" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_94" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - input: "save/RestoreV2_1:94" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_95" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - input: "save/RestoreV2_1:95" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - dim { - size: 64 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_96" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - input: "save/RestoreV2_1:96" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_97" - op: "Assign" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - input: "save/RestoreV2_1:97" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_98" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale" - input: "save/RestoreV2_1:98" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_99" - op: "Assign" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:99" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_100" - op: "Assign" - input: "encoder/rms_norm/scale" - input: "save/RestoreV2_1:100" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_101" - op: "Assign" - input: "encoder/rms_norm/scale_slot_v" - input: "save/RestoreV2_1:101" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_102" - op: "Assign" - input: "shared/embedding" - input: "save/RestoreV2_1:102" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/Assign_103" - op: "Assign" - input: "shared/embedding_slot_v" - input: "save/RestoreV2_1:103" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 32128 - } - dim { - size: 32 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } -} -node { - name: "save/restore_shard_1" - op: "NoOp" - input: "^save/Assign" - input: "^save/Assign_1" - input: "^save/Assign_10" - input: "^save/Assign_100" - input: "^save/Assign_101" - input: "^save/Assign_102" - input: "^save/Assign_103" - input: "^save/Assign_11" - input: "^save/Assign_12" - input: "^save/Assign_13" - input: "^save/Assign_14" - input: "^save/Assign_15" - input: "^save/Assign_16" - input: "^save/Assign_17" - input: "^save/Assign_18" - input: "^save/Assign_19" - input: "^save/Assign_2" - input: "^save/Assign_20" - input: "^save/Assign_21" - input: "^save/Assign_22" - input: "^save/Assign_23" - input: "^save/Assign_24" - input: "^save/Assign_25" - input: "^save/Assign_26" - input: "^save/Assign_27" - input: "^save/Assign_28" - input: "^save/Assign_29" - input: "^save/Assign_3" - input: "^save/Assign_30" - input: "^save/Assign_31" - input: "^save/Assign_32" - input: "^save/Assign_33" - input: "^save/Assign_34" - input: "^save/Assign_35" - input: "^save/Assign_36" - input: "^save/Assign_37" - input: "^save/Assign_38" - input: "^save/Assign_39" - input: "^save/Assign_4" - input: "^save/Assign_40" - input: "^save/Assign_41" - input: "^save/Assign_42" - input: "^save/Assign_43" - input: "^save/Assign_44" - input: "^save/Assign_45" - input: "^save/Assign_46" - input: "^save/Assign_47" - input: "^save/Assign_48" - input: "^save/Assign_49" - input: "^save/Assign_5" - input: "^save/Assign_50" - input: "^save/Assign_51" - input: "^save/Assign_52" - input: "^save/Assign_53" - input: "^save/Assign_54" - input: "^save/Assign_55" - input: "^save/Assign_56" - input: "^save/Assign_57" - input: "^save/Assign_58" - input: "^save/Assign_59" - input: "^save/Assign_6" - input: "^save/Assign_60" - input: "^save/Assign_61" - input: "^save/Assign_62" - input: "^save/Assign_63" - input: "^save/Assign_64" - input: "^save/Assign_65" - input: "^save/Assign_66" - input: "^save/Assign_67" - input: "^save/Assign_68" - input: "^save/Assign_69" - input: "^save/Assign_7" - input: "^save/Assign_70" - input: "^save/Assign_71" - input: "^save/Assign_72" - input: "^save/Assign_73" - input: "^save/Assign_74" - input: "^save/Assign_75" - input: "^save/Assign_76" - input: "^save/Assign_77" - input: "^save/Assign_78" - input: "^save/Assign_79" - input: "^save/Assign_8" - input: "^save/Assign_80" - input: "^save/Assign_81" - input: "^save/Assign_82" - input: "^save/Assign_83" - input: "^save/Assign_84" - input: "^save/Assign_85" - input: "^save/Assign_86" - input: "^save/Assign_87" - input: "^save/Assign_88" - input: "^save/Assign_89" - input: "^save/Assign_9" - input: "^save/Assign_90" - input: "^save/Assign_91" - input: "^save/Assign_92" - input: "^save/Assign_93" - input: "^save/Assign_94" - input: "^save/Assign_95" - input: "^save/Assign_96" - input: "^save/Assign_97" - input: "^save/Assign_98" - input: "^save/Assign_99" - device: "/device:CPU:0" -} -node { - name: "save/restore_all/NoOp" - op: "NoOp" - input: "^save/restore_shard" -} -node { - name: "save/restore_all/NoOp_1" - op: "NoOp" - input: "^save/restore_shard_1" - device: "/device:CPU:0" -} -node { - name: "save/restore_all" - op: "NoOp" - input: "^save/restore_all/NoOp" - input: "^save/restore_all/NoOp_1" -} -node { - name: "group_deps_1/NoOp" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps" - input: "^decoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/k_slot_v_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/o_slot_v_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/q_slot_v_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/v_slot_v_1/group_deps" - input: "^decoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - input: "^decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps" - input: "^decoder/block_000/layer_002/rms_norm/scale_slot_v_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps" - input: "^decoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/k_slot_v_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/o_slot_v_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/q_slot_v_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/v_slot_v_1/group_deps" - input: "^decoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - input: "^decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps" - input: "^decoder/block_001/layer_002/rms_norm/scale_slot_v_1/group_deps" - input: "^decoder/logits/kernel_slot_v_1/group_deps" - input: "^decoder/rms_norm/scale_slot_v_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps" - input: "^encoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - input: "^encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps" - input: "^encoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps" - input: "^encoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps" - input: "^encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps" - input: "^encoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps" - input: "^encoder/rms_norm/scale_slot_v_1/group_deps" - input: "^shared/embedding_slot_v_1/group_deps" -} -node { - name: "group_deps_1/NoOp_1" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/k_1/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/o_1/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/q_1/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/v_1/Assign" - input: "^decoder/block_000/layer_000/rms_norm/scale_1/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/k_1/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/o_1/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/q_1/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/v_1/Assign" - input: "^decoder/block_000/layer_001/rms_norm/scale_1/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/Assign" - input: "^decoder/block_000/layer_002/rms_norm/scale_1/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/k_1/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/o_1/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/q_1/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/v_1/Assign" - input: "^decoder/block_001/layer_000/rms_norm/scale_1/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/k_1/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/o_1/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/q_1/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/v_1/Assign" - input: "^decoder/block_001/layer_001/rms_norm/scale_1/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/Assign" - input: "^decoder/block_001/layer_002/rms_norm/scale_1/Assign" - input: "^decoder/logits/kernel_1/Assign" - input: "^decoder/rms_norm/scale_1/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/k_1/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/o_1/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/q_1/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/v_1/Assign" - input: "^encoder/block_000/layer_000/rms_norm/scale_1/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/Assign" - input: "^encoder/block_000/layer_001/rms_norm/scale_1/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/k_1/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/o_1/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/q_1/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/v_1/Assign" - input: "^encoder/block_001/layer_000/rms_norm/scale_1/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/Assign" - input: "^encoder/block_001/layer_001/rms_norm/scale_1/Assign" - input: "^encoder/rms_norm/scale_1/Assign" - input: "^shared/embedding_1/Assign" - device: "/device:CPU:0" -} -node { - name: "group_deps_1" - op: "NoOp" - input: "^group_deps_1/NoOp" - input: "^group_deps_1/NoOp_1" -} -node { - name: "loss/tags" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "loss" - } - } - } -} -node { - name: "loss" - op: "ScalarSummary" - input: "loss/tags" - input: "Print" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "mtf_restore_hook/group_deps" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/k_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_000/SelfAttention/o_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_000/SelfAttention/q_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_000/SelfAttention/v_1/group_deps" - input: "^decoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_000/rms_norm/scale_1/group_deps" - input: "^decoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_001/EncDecAttention/k_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/k_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_001/EncDecAttention/o_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/o_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_001/EncDecAttention/q_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/q_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_001/EncDecAttention/v_1/group_deps" - input: "^decoder/block_000/layer_001/EncDecAttention/v_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_001/rms_norm/scale_1/group_deps" - input: "^decoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_1/group_deps" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_1/group_deps" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_002/DenseReluDense/wo/kernel_1/group_deps" - input: "^decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - input: "^decoder/block_000/layer_002/rms_norm/scale_1/group_deps" - input: "^decoder/block_000/layer_002/rms_norm/scale_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_000/SelfAttention/k_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_000/SelfAttention/o_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_000/SelfAttention/q_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_000/SelfAttention/v_1/group_deps" - input: "^decoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_000/rms_norm/scale_1/group_deps" - input: "^decoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_001/EncDecAttention/k_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/k_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_001/EncDecAttention/o_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/o_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_001/EncDecAttention/q_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/q_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_001/EncDecAttention/v_1/group_deps" - input: "^decoder/block_001/layer_001/EncDecAttention/v_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_001/rms_norm/scale_1/group_deps" - input: "^decoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_1/group_deps" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_1/group_deps" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_002/DenseReluDense/wo/kernel_1/group_deps" - input: "^decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - input: "^decoder/block_001/layer_002/rms_norm/scale_1/group_deps" - input: "^decoder/block_001/layer_002/rms_norm/scale_slot_v_1/group_deps_1" - input: "^decoder/logits/kernel_1/group_deps" - input: "^decoder/logits/kernel_slot_v_1/group_deps_1" - input: "^decoder/rms_norm/scale_1/group_deps" - input: "^decoder/rms_norm/scale_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_000/SelfAttention/k_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_000/SelfAttention/o_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_000/SelfAttention/q_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_000/SelfAttention/relative_attention_bias_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_000/SelfAttention/v_1/group_deps" - input: "^encoder/block_000/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_000/rms_norm/scale_1/group_deps" - input: "^encoder/block_000/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_1/group_deps" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_1/group_deps" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_001/DenseReluDense/wo/kernel_1/group_deps" - input: "^encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - input: "^encoder/block_000/layer_001/rms_norm/scale_1/group_deps" - input: "^encoder/block_000/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_000/SelfAttention/k_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/k_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_000/SelfAttention/o_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/o_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_000/SelfAttention/q_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/q_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_000/SelfAttention/v_1/group_deps" - input: "^encoder/block_001/layer_000/SelfAttention/v_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_000/rms_norm/scale_1/group_deps" - input: "^encoder/block_001/layer_000/rms_norm/scale_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_1/group_deps" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_1/group_deps" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_001/DenseReluDense/wo/kernel_1/group_deps" - input: "^encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v_1/group_deps_1" - input: "^encoder/block_001/layer_001/rms_norm/scale_1/group_deps" - input: "^encoder/block_001/layer_001/rms_norm/scale_slot_v_1/group_deps_1" - input: "^encoder/rms_norm/scale_1/group_deps" - input: "^encoder/rms_norm/scale_slot_v_1/group_deps_1" - input: "^shared/embedding_1/group_deps" - input: "^shared/embedding_slot_v_1/group_deps_1" -} -node { - name: "init/NoOp" - op: "NoOp" - input: "^global_step/Assign" -} -node { - name: "init/NoOp_1" - op: "NoOp" - input: "^decoder/block_000/layer_000/SelfAttention/k/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/k_slot_v/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/o/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/o_slot_v/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/q/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/q_slot_v/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/relative_attention_bias/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/v/Assign" - input: "^decoder/block_000/layer_000/SelfAttention/v_slot_v/Assign" - input: "^decoder/block_000/layer_000/rms_norm/scale/Assign" - input: "^decoder/block_000/layer_000/rms_norm/scale_slot_v/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/k/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/k_slot_v/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/o/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/o_slot_v/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/q/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/q_slot_v/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/v/Assign" - input: "^decoder/block_000/layer_001/EncDecAttention/v_slot_v/Assign" - input: "^decoder/block_000/layer_001/rms_norm/scale/Assign" - input: "^decoder/block_000/layer_001/rms_norm/scale_slot_v/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_0/kernel/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_1/kernel/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wo/kernel/Assign" - input: "^decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v/Assign" - input: "^decoder/block_000/layer_002/rms_norm/scale/Assign" - input: "^decoder/block_000/layer_002/rms_norm/scale_slot_v/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/k/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/k_slot_v/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/o/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/o_slot_v/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/q/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/q_slot_v/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/v/Assign" - input: "^decoder/block_001/layer_000/SelfAttention/v_slot_v/Assign" - input: "^decoder/block_001/layer_000/rms_norm/scale/Assign" - input: "^decoder/block_001/layer_000/rms_norm/scale_slot_v/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/k/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/k_slot_v/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/o/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/o_slot_v/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/q/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/q_slot_v/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/v/Assign" - input: "^decoder/block_001/layer_001/EncDecAttention/v_slot_v/Assign" - input: "^decoder/block_001/layer_001/rms_norm/scale/Assign" - input: "^decoder/block_001/layer_001/rms_norm/scale_slot_v/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_0/kernel/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_1/kernel/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wo/kernel/Assign" - input: "^decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v/Assign" - input: "^decoder/block_001/layer_002/rms_norm/scale/Assign" - input: "^decoder/block_001/layer_002/rms_norm/scale_slot_v/Assign" - input: "^decoder/logits/kernel/Assign" - input: "^decoder/logits/kernel_slot_v/Assign" - input: "^decoder/rms_norm/scale/Assign" - input: "^decoder/rms_norm/scale_slot_v/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/k/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/k_slot_v/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/o/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/o_slot_v/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/q/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/q_slot_v/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/relative_attention_bias/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/v/Assign" - input: "^encoder/block_000/layer_000/SelfAttention/v_slot_v/Assign" - input: "^encoder/block_000/layer_000/rms_norm/scale/Assign" - input: "^encoder/block_000/layer_000/rms_norm/scale_slot_v/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_0/kernel/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_1/kernel/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wo/kernel/Assign" - input: "^encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v/Assign" - input: "^encoder/block_000/layer_001/rms_norm/scale/Assign" - input: "^encoder/block_000/layer_001/rms_norm/scale_slot_v/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/k/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/k_slot_v/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/o/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/o_slot_v/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/q/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/q_slot_v/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/v/Assign" - input: "^encoder/block_001/layer_000/SelfAttention/v_slot_v/Assign" - input: "^encoder/block_001/layer_000/rms_norm/scale/Assign" - input: "^encoder/block_001/layer_000/rms_norm/scale_slot_v/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_0/kernel/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_1/kernel/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wo/kernel/Assign" - input: "^encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v/Assign" - input: "^encoder/block_001/layer_001/rms_norm/scale/Assign" - input: "^encoder/block_001/layer_001/rms_norm/scale_slot_v/Assign" - input: "^encoder/rms_norm/scale/Assign" - input: "^encoder/rms_norm/scale_slot_v/Assign" - input: "^shared/embedding/Assign" - input: "^shared/embedding_slot_v/Assign" - device: "/device:CPU:0" -} -node { - name: "init" - op: "NoOp" - input: "^init/NoOp" - input: "^init/NoOp_1" -} -node { - name: "init_1" - op: "NoOp" -} -node { - name: "group_deps_2" - op: "NoOp" - input: "^init" - input: "^init_1" -} -node { - name: "report_uninitialized_variables/VarIsInitializedOp" - op: "VarIsInitializedOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized" - op: "IsVariableInitialized" - input: "shared/embedding" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_1" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_2" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_3" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_4" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_5" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_6" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_7" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_8" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_9" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_10" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_11" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_12" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_13" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_14" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_15" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_16" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_17" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_18" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_19" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_20" - op: "IsVariableInitialized" - input: "encoder/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_21" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_22" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_23" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_24" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_25" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_26" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_27" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_28" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_29" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_30" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_31" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_32" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_33" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_34" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_35" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_36" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_37" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_38" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_39" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_40" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_41" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_42" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_43" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_44" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_45" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_46" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_47" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_48" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_49" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_50" - op: "IsVariableInitialized" - input: "decoder/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_51" - op: "IsVariableInitialized" - input: "decoder/logits/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_52" - op: "IsVariableInitialized" - input: "shared/embedding_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_53" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_54" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_55" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_56" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_57" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_58" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_59" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_60" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_61" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_62" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_63" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_64" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_65" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_66" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_67" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_68" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_69" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_70" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_71" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_72" - op: "IsVariableInitialized" - input: "encoder/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_73" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_74" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_75" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_76" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_77" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_78" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_79" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_80" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_81" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_82" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_83" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_84" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_85" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_86" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_87" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_88" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_89" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_90" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_91" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_92" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_93" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_94" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_95" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_96" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_97" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_98" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_99" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_100" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_101" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_102" - op: "IsVariableInitialized" - input: "decoder/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/IsVariableInitialized_103" - op: "IsVariableInitialized" - input: "decoder/logits/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables/stack" - op: "Pack" - input: "report_uninitialized_variables/VarIsInitializedOp" - input: "report_uninitialized_variables/IsVariableInitialized" - input: "report_uninitialized_variables/IsVariableInitialized_1" - input: "report_uninitialized_variables/IsVariableInitialized_2" - input: "report_uninitialized_variables/IsVariableInitialized_3" - input: "report_uninitialized_variables/IsVariableInitialized_4" - input: "report_uninitialized_variables/IsVariableInitialized_5" - input: "report_uninitialized_variables/IsVariableInitialized_6" - input: "report_uninitialized_variables/IsVariableInitialized_7" - input: "report_uninitialized_variables/IsVariableInitialized_8" - input: "report_uninitialized_variables/IsVariableInitialized_9" - input: "report_uninitialized_variables/IsVariableInitialized_10" - input: "report_uninitialized_variables/IsVariableInitialized_11" - input: "report_uninitialized_variables/IsVariableInitialized_12" - input: "report_uninitialized_variables/IsVariableInitialized_13" - input: "report_uninitialized_variables/IsVariableInitialized_14" - input: "report_uninitialized_variables/IsVariableInitialized_15" - input: "report_uninitialized_variables/IsVariableInitialized_16" - input: "report_uninitialized_variables/IsVariableInitialized_17" - input: "report_uninitialized_variables/IsVariableInitialized_18" - input: "report_uninitialized_variables/IsVariableInitialized_19" - input: "report_uninitialized_variables/IsVariableInitialized_20" - input: "report_uninitialized_variables/IsVariableInitialized_21" - input: "report_uninitialized_variables/IsVariableInitialized_22" - input: "report_uninitialized_variables/IsVariableInitialized_23" - input: "report_uninitialized_variables/IsVariableInitialized_24" - input: "report_uninitialized_variables/IsVariableInitialized_25" - input: "report_uninitialized_variables/IsVariableInitialized_26" - input: "report_uninitialized_variables/IsVariableInitialized_27" - input: "report_uninitialized_variables/IsVariableInitialized_28" - input: "report_uninitialized_variables/IsVariableInitialized_29" - input: "report_uninitialized_variables/IsVariableInitialized_30" - input: "report_uninitialized_variables/IsVariableInitialized_31" - input: "report_uninitialized_variables/IsVariableInitialized_32" - input: "report_uninitialized_variables/IsVariableInitialized_33" - input: "report_uninitialized_variables/IsVariableInitialized_34" - input: "report_uninitialized_variables/IsVariableInitialized_35" - input: "report_uninitialized_variables/IsVariableInitialized_36" - input: "report_uninitialized_variables/IsVariableInitialized_37" - input: "report_uninitialized_variables/IsVariableInitialized_38" - input: "report_uninitialized_variables/IsVariableInitialized_39" - input: "report_uninitialized_variables/IsVariableInitialized_40" - input: "report_uninitialized_variables/IsVariableInitialized_41" - input: "report_uninitialized_variables/IsVariableInitialized_42" - input: "report_uninitialized_variables/IsVariableInitialized_43" - input: "report_uninitialized_variables/IsVariableInitialized_44" - input: "report_uninitialized_variables/IsVariableInitialized_45" - input: "report_uninitialized_variables/IsVariableInitialized_46" - input: "report_uninitialized_variables/IsVariableInitialized_47" - input: "report_uninitialized_variables/IsVariableInitialized_48" - input: "report_uninitialized_variables/IsVariableInitialized_49" - input: "report_uninitialized_variables/IsVariableInitialized_50" - input: "report_uninitialized_variables/IsVariableInitialized_51" - input: "report_uninitialized_variables/IsVariableInitialized_52" - input: "report_uninitialized_variables/IsVariableInitialized_53" - input: "report_uninitialized_variables/IsVariableInitialized_54" - input: "report_uninitialized_variables/IsVariableInitialized_55" - input: "report_uninitialized_variables/IsVariableInitialized_56" - input: "report_uninitialized_variables/IsVariableInitialized_57" - input: "report_uninitialized_variables/IsVariableInitialized_58" - input: "report_uninitialized_variables/IsVariableInitialized_59" - input: "report_uninitialized_variables/IsVariableInitialized_60" - input: "report_uninitialized_variables/IsVariableInitialized_61" - input: "report_uninitialized_variables/IsVariableInitialized_62" - input: "report_uninitialized_variables/IsVariableInitialized_63" - input: "report_uninitialized_variables/IsVariableInitialized_64" - input: "report_uninitialized_variables/IsVariableInitialized_65" - input: "report_uninitialized_variables/IsVariableInitialized_66" - input: "report_uninitialized_variables/IsVariableInitialized_67" - input: "report_uninitialized_variables/IsVariableInitialized_68" - input: "report_uninitialized_variables/IsVariableInitialized_69" - input: "report_uninitialized_variables/IsVariableInitialized_70" - input: "report_uninitialized_variables/IsVariableInitialized_71" - input: "report_uninitialized_variables/IsVariableInitialized_72" - input: "report_uninitialized_variables/IsVariableInitialized_73" - input: "report_uninitialized_variables/IsVariableInitialized_74" - input: "report_uninitialized_variables/IsVariableInitialized_75" - input: "report_uninitialized_variables/IsVariableInitialized_76" - input: "report_uninitialized_variables/IsVariableInitialized_77" - input: "report_uninitialized_variables/IsVariableInitialized_78" - input: "report_uninitialized_variables/IsVariableInitialized_79" - input: "report_uninitialized_variables/IsVariableInitialized_80" - input: "report_uninitialized_variables/IsVariableInitialized_81" - input: "report_uninitialized_variables/IsVariableInitialized_82" - input: "report_uninitialized_variables/IsVariableInitialized_83" - input: "report_uninitialized_variables/IsVariableInitialized_84" - input: "report_uninitialized_variables/IsVariableInitialized_85" - input: "report_uninitialized_variables/IsVariableInitialized_86" - input: "report_uninitialized_variables/IsVariableInitialized_87" - input: "report_uninitialized_variables/IsVariableInitialized_88" - input: "report_uninitialized_variables/IsVariableInitialized_89" - input: "report_uninitialized_variables/IsVariableInitialized_90" - input: "report_uninitialized_variables/IsVariableInitialized_91" - input: "report_uninitialized_variables/IsVariableInitialized_92" - input: "report_uninitialized_variables/IsVariableInitialized_93" - input: "report_uninitialized_variables/IsVariableInitialized_94" - input: "report_uninitialized_variables/IsVariableInitialized_95" - input: "report_uninitialized_variables/IsVariableInitialized_96" - input: "report_uninitialized_variables/IsVariableInitialized_97" - input: "report_uninitialized_variables/IsVariableInitialized_98" - input: "report_uninitialized_variables/IsVariableInitialized_99" - input: "report_uninitialized_variables/IsVariableInitialized_100" - input: "report_uninitialized_variables/IsVariableInitialized_101" - input: "report_uninitialized_variables/IsVariableInitialized_102" - input: "report_uninitialized_variables/IsVariableInitialized_103" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 105 - } - } - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables/LogicalNot" - op: "LogicalNot" - input: "report_uninitialized_variables/stack" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables/Const" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 105 - } - } - string_val: "global_step" - string_val: "shared/embedding" - string_val: "encoder/block_000/layer_000/rms_norm/scale" - string_val: "encoder/block_000/layer_000/SelfAttention/q" - string_val: "encoder/block_000/layer_000/SelfAttention/k" - string_val: "encoder/block_000/layer_000/SelfAttention/v" - string_val: "encoder/block_000/layer_000/SelfAttention/o" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "encoder/block_000/layer_001/rms_norm/scale" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/block_001/layer_000/rms_norm/scale" - string_val: "encoder/block_001/layer_000/SelfAttention/q" - string_val: "encoder/block_001/layer_000/SelfAttention/k" - string_val: "encoder/block_001/layer_000/SelfAttention/v" - string_val: "encoder/block_001/layer_000/SelfAttention/o" - string_val: "encoder/block_001/layer_001/rms_norm/scale" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/rms_norm/scale" - string_val: "decoder/block_000/layer_000/rms_norm/scale" - string_val: "decoder/block_000/layer_000/SelfAttention/q" - string_val: "decoder/block_000/layer_000/SelfAttention/k" - string_val: "decoder/block_000/layer_000/SelfAttention/v" - string_val: "decoder/block_000/layer_000/SelfAttention/o" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "decoder/block_000/layer_001/rms_norm/scale" - string_val: "decoder/block_000/layer_001/EncDecAttention/q" - string_val: "decoder/block_000/layer_001/EncDecAttention/k" - string_val: "decoder/block_000/layer_001/EncDecAttention/v" - string_val: "decoder/block_000/layer_001/EncDecAttention/o" - string_val: "decoder/block_000/layer_002/rms_norm/scale" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/block_001/layer_000/rms_norm/scale" - string_val: "decoder/block_001/layer_000/SelfAttention/q" - string_val: "decoder/block_001/layer_000/SelfAttention/k" - string_val: "decoder/block_001/layer_000/SelfAttention/v" - string_val: "decoder/block_001/layer_000/SelfAttention/o" - string_val: "decoder/block_001/layer_001/rms_norm/scale" - string_val: "decoder/block_001/layer_001/EncDecAttention/q" - string_val: "decoder/block_001/layer_001/EncDecAttention/k" - string_val: "decoder/block_001/layer_001/EncDecAttention/v" - string_val: "decoder/block_001/layer_001/EncDecAttention/o" - string_val: "decoder/block_001/layer_002/rms_norm/scale" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/rms_norm/scale" - string_val: "decoder/logits/kernel" - string_val: "shared/embedding_slot_v" - string_val: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/rms_norm/scale_slot_v" - string_val: "decoder/logits/kernel_slot_v" - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Shape" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 105 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice/stack" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice/stack_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice/stack_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice" - op: "StridedSlice" - input: "report_uninitialized_variables/boolean_mask/Shape" - input: "report_uninitialized_variables/boolean_mask/strided_slice/stack" - input: "report_uninitialized_variables/boolean_mask/strided_slice/stack_1" - input: "report_uninitialized_variables/boolean_mask/strided_slice/stack_2" - device: "/device:CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Prod/reduction_indices" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Prod" - op: "Prod" - input: "report_uninitialized_variables/boolean_mask/strided_slice" - input: "report_uninitialized_variables/boolean_mask/Prod/reduction_indices" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Shape_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 105 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_1/stack" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_1/stack_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_1/stack_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_1" - op: "StridedSlice" - input: "report_uninitialized_variables/boolean_mask/Shape_1" - input: "report_uninitialized_variables/boolean_mask/strided_slice_1/stack" - input: "report_uninitialized_variables/boolean_mask/strided_slice_1/stack_1" - input: "report_uninitialized_variables/boolean_mask/strided_slice_1/stack_2" - device: "/device:CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Shape_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 105 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_2/stack" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_2/stack_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_2/stack_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/strided_slice_2" - op: "StridedSlice" - input: "report_uninitialized_variables/boolean_mask/Shape_2" - input: "report_uninitialized_variables/boolean_mask/strided_slice_2/stack" - input: "report_uninitialized_variables/boolean_mask/strided_slice_2/stack_1" - input: "report_uninitialized_variables/boolean_mask/strided_slice_2/stack_2" - device: "/device:CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 1 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/concat/values_1" - op: "Pack" - input: "report_uninitialized_variables/boolean_mask/Prod" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/concat/axis" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/concat" - op: "ConcatV2" - input: "report_uninitialized_variables/boolean_mask/strided_slice_1" - input: "report_uninitialized_variables/boolean_mask/concat/values_1" - input: "report_uninitialized_variables/boolean_mask/strided_slice_2" - input: "report_uninitialized_variables/boolean_mask/concat/axis" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 3 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Reshape" - op: "Reshape" - input: "report_uninitialized_variables/Const" - input: "report_uninitialized_variables/boolean_mask/concat" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Reshape_1/shape" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Reshape_1" - op: "Reshape" - input: "report_uninitialized_variables/LogicalNot" - input: "report_uninitialized_variables/boolean_mask/Reshape_1/shape" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Where" - op: "Where" - input: "report_uninitialized_variables/boolean_mask/Reshape_1" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/Squeeze" - op: "Squeeze" - input: "report_uninitialized_variables/boolean_mask/Where" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "squeeze_dims" - value { - list { - i: 1 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/GatherV2/axis" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables/boolean_mask/GatherV2" - op: "GatherV2" - input: "report_uninitialized_variables/boolean_mask/Reshape" - input: "report_uninitialized_variables/boolean_mask/Squeeze" - input: "report_uninitialized_variables/boolean_mask/GatherV2/axis" - device: "/device:CPU:0" - attr { - key: "Taxis" - value { - type: DT_INT32 - } - } - attr { - key: "Tindices" - value { - type: DT_INT64 - } - } - attr { - key: "Tparams" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "batch_dims" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_resources/Const" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "concat/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat" - op: "ConcatV2" - input: "report_uninitialized_variables/boolean_mask/GatherV2" - input: "report_uninitialized_resources/Const" - input: "concat/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables_1/VarIsInitializedOp" - op: "VarIsInitializedOp" - input: "global_step" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized" - op: "IsVariableInitialized" - input: "shared/embedding" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_1" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_2" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_3" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_4" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_5" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_6" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_7" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_8" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_9" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_10" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_11" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_12" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_13" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_14" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_15" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_16" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_17" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_18" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_19" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_20" - op: "IsVariableInitialized" - input: "encoder/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_21" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_22" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_23" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_24" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_25" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_26" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_27" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_28" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_29" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_30" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_31" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_32" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_33" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_34" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_35" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_36" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_37" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_38" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_39" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_40" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_41" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_42" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/q" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_43" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/k" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_44" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_45" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/o" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_46" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_47" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_48" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_49" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_50" - op: "IsVariableInitialized" - input: "decoder/rms_norm/scale" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_51" - op: "IsVariableInitialized" - input: "decoder/logits/kernel" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_52" - op: "IsVariableInitialized" - input: "shared/embedding_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@shared/embedding_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_53" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_54" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_55" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_56" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_57" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_58" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_59" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_60" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_61" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_62" - op: "IsVariableInitialized" - input: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_63" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_64" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_65" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_66" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_67" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_68" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_69" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_70" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_71" - op: "IsVariableInitialized" - input: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_72" - op: "IsVariableInitialized" - input: "encoder/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@encoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_73" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_74" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_75" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_76" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_77" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_78" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_79" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_80" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_81" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_82" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_83" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_84" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_85" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_86" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_87" - op: "IsVariableInitialized" - input: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_88" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_89" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_90" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_91" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_92" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_000/SelfAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_93" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_94" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/q_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_95" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/k_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_96" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/v_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_97" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_001/EncDecAttention/o_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_98" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_99" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_100" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_101" - op: "IsVariableInitialized" - input: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_102" - op: "IsVariableInitialized" - input: "decoder/rms_norm/scale_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/rms_norm/scale_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/IsVariableInitialized_103" - op: "IsVariableInitialized" - input: "decoder/logits/kernel_slot_v" - device: "/device:CPU:0" - attr { - key: "_class" - value { - list { - s: "loc:@decoder/logits/kernel_slot_v" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } -} -node { - name: "report_uninitialized_variables_1/stack" - op: "Pack" - input: "report_uninitialized_variables_1/VarIsInitializedOp" - input: "report_uninitialized_variables_1/IsVariableInitialized" - input: "report_uninitialized_variables_1/IsVariableInitialized_1" - input: "report_uninitialized_variables_1/IsVariableInitialized_2" - input: "report_uninitialized_variables_1/IsVariableInitialized_3" - input: "report_uninitialized_variables_1/IsVariableInitialized_4" - input: "report_uninitialized_variables_1/IsVariableInitialized_5" - input: "report_uninitialized_variables_1/IsVariableInitialized_6" - input: "report_uninitialized_variables_1/IsVariableInitialized_7" - input: "report_uninitialized_variables_1/IsVariableInitialized_8" - input: "report_uninitialized_variables_1/IsVariableInitialized_9" - input: "report_uninitialized_variables_1/IsVariableInitialized_10" - input: "report_uninitialized_variables_1/IsVariableInitialized_11" - input: "report_uninitialized_variables_1/IsVariableInitialized_12" - input: "report_uninitialized_variables_1/IsVariableInitialized_13" - input: "report_uninitialized_variables_1/IsVariableInitialized_14" - input: "report_uninitialized_variables_1/IsVariableInitialized_15" - input: "report_uninitialized_variables_1/IsVariableInitialized_16" - input: "report_uninitialized_variables_1/IsVariableInitialized_17" - input: "report_uninitialized_variables_1/IsVariableInitialized_18" - input: "report_uninitialized_variables_1/IsVariableInitialized_19" - input: "report_uninitialized_variables_1/IsVariableInitialized_20" - input: "report_uninitialized_variables_1/IsVariableInitialized_21" - input: "report_uninitialized_variables_1/IsVariableInitialized_22" - input: "report_uninitialized_variables_1/IsVariableInitialized_23" - input: "report_uninitialized_variables_1/IsVariableInitialized_24" - input: "report_uninitialized_variables_1/IsVariableInitialized_25" - input: "report_uninitialized_variables_1/IsVariableInitialized_26" - input: "report_uninitialized_variables_1/IsVariableInitialized_27" - input: "report_uninitialized_variables_1/IsVariableInitialized_28" - input: "report_uninitialized_variables_1/IsVariableInitialized_29" - input: "report_uninitialized_variables_1/IsVariableInitialized_30" - input: "report_uninitialized_variables_1/IsVariableInitialized_31" - input: "report_uninitialized_variables_1/IsVariableInitialized_32" - input: "report_uninitialized_variables_1/IsVariableInitialized_33" - input: "report_uninitialized_variables_1/IsVariableInitialized_34" - input: "report_uninitialized_variables_1/IsVariableInitialized_35" - input: "report_uninitialized_variables_1/IsVariableInitialized_36" - input: "report_uninitialized_variables_1/IsVariableInitialized_37" - input: "report_uninitialized_variables_1/IsVariableInitialized_38" - input: "report_uninitialized_variables_1/IsVariableInitialized_39" - input: "report_uninitialized_variables_1/IsVariableInitialized_40" - input: "report_uninitialized_variables_1/IsVariableInitialized_41" - input: "report_uninitialized_variables_1/IsVariableInitialized_42" - input: "report_uninitialized_variables_1/IsVariableInitialized_43" - input: "report_uninitialized_variables_1/IsVariableInitialized_44" - input: "report_uninitialized_variables_1/IsVariableInitialized_45" - input: "report_uninitialized_variables_1/IsVariableInitialized_46" - input: "report_uninitialized_variables_1/IsVariableInitialized_47" - input: "report_uninitialized_variables_1/IsVariableInitialized_48" - input: "report_uninitialized_variables_1/IsVariableInitialized_49" - input: "report_uninitialized_variables_1/IsVariableInitialized_50" - input: "report_uninitialized_variables_1/IsVariableInitialized_51" - input: "report_uninitialized_variables_1/IsVariableInitialized_52" - input: "report_uninitialized_variables_1/IsVariableInitialized_53" - input: "report_uninitialized_variables_1/IsVariableInitialized_54" - input: "report_uninitialized_variables_1/IsVariableInitialized_55" - input: "report_uninitialized_variables_1/IsVariableInitialized_56" - input: "report_uninitialized_variables_1/IsVariableInitialized_57" - input: "report_uninitialized_variables_1/IsVariableInitialized_58" - input: "report_uninitialized_variables_1/IsVariableInitialized_59" - input: "report_uninitialized_variables_1/IsVariableInitialized_60" - input: "report_uninitialized_variables_1/IsVariableInitialized_61" - input: "report_uninitialized_variables_1/IsVariableInitialized_62" - input: "report_uninitialized_variables_1/IsVariableInitialized_63" - input: "report_uninitialized_variables_1/IsVariableInitialized_64" - input: "report_uninitialized_variables_1/IsVariableInitialized_65" - input: "report_uninitialized_variables_1/IsVariableInitialized_66" - input: "report_uninitialized_variables_1/IsVariableInitialized_67" - input: "report_uninitialized_variables_1/IsVariableInitialized_68" - input: "report_uninitialized_variables_1/IsVariableInitialized_69" - input: "report_uninitialized_variables_1/IsVariableInitialized_70" - input: "report_uninitialized_variables_1/IsVariableInitialized_71" - input: "report_uninitialized_variables_1/IsVariableInitialized_72" - input: "report_uninitialized_variables_1/IsVariableInitialized_73" - input: "report_uninitialized_variables_1/IsVariableInitialized_74" - input: "report_uninitialized_variables_1/IsVariableInitialized_75" - input: "report_uninitialized_variables_1/IsVariableInitialized_76" - input: "report_uninitialized_variables_1/IsVariableInitialized_77" - input: "report_uninitialized_variables_1/IsVariableInitialized_78" - input: "report_uninitialized_variables_1/IsVariableInitialized_79" - input: "report_uninitialized_variables_1/IsVariableInitialized_80" - input: "report_uninitialized_variables_1/IsVariableInitialized_81" - input: "report_uninitialized_variables_1/IsVariableInitialized_82" - input: "report_uninitialized_variables_1/IsVariableInitialized_83" - input: "report_uninitialized_variables_1/IsVariableInitialized_84" - input: "report_uninitialized_variables_1/IsVariableInitialized_85" - input: "report_uninitialized_variables_1/IsVariableInitialized_86" - input: "report_uninitialized_variables_1/IsVariableInitialized_87" - input: "report_uninitialized_variables_1/IsVariableInitialized_88" - input: "report_uninitialized_variables_1/IsVariableInitialized_89" - input: "report_uninitialized_variables_1/IsVariableInitialized_90" - input: "report_uninitialized_variables_1/IsVariableInitialized_91" - input: "report_uninitialized_variables_1/IsVariableInitialized_92" - input: "report_uninitialized_variables_1/IsVariableInitialized_93" - input: "report_uninitialized_variables_1/IsVariableInitialized_94" - input: "report_uninitialized_variables_1/IsVariableInitialized_95" - input: "report_uninitialized_variables_1/IsVariableInitialized_96" - input: "report_uninitialized_variables_1/IsVariableInitialized_97" - input: "report_uninitialized_variables_1/IsVariableInitialized_98" - input: "report_uninitialized_variables_1/IsVariableInitialized_99" - input: "report_uninitialized_variables_1/IsVariableInitialized_100" - input: "report_uninitialized_variables_1/IsVariableInitialized_101" - input: "report_uninitialized_variables_1/IsVariableInitialized_102" - input: "report_uninitialized_variables_1/IsVariableInitialized_103" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 105 - } - } - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables_1/LogicalNot" - op: "LogicalNot" - input: "report_uninitialized_variables_1/stack" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables_1/Const" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 105 - } - } - string_val: "global_step" - string_val: "shared/embedding" - string_val: "encoder/block_000/layer_000/rms_norm/scale" - string_val: "encoder/block_000/layer_000/SelfAttention/q" - string_val: "encoder/block_000/layer_000/SelfAttention/k" - string_val: "encoder/block_000/layer_000/SelfAttention/v" - string_val: "encoder/block_000/layer_000/SelfAttention/o" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "encoder/block_000/layer_001/rms_norm/scale" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/block_001/layer_000/rms_norm/scale" - string_val: "encoder/block_001/layer_000/SelfAttention/q" - string_val: "encoder/block_001/layer_000/SelfAttention/k" - string_val: "encoder/block_001/layer_000/SelfAttention/v" - string_val: "encoder/block_001/layer_000/SelfAttention/o" - string_val: "encoder/block_001/layer_001/rms_norm/scale" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel" - string_val: "encoder/rms_norm/scale" - string_val: "decoder/block_000/layer_000/rms_norm/scale" - string_val: "decoder/block_000/layer_000/SelfAttention/q" - string_val: "decoder/block_000/layer_000/SelfAttention/k" - string_val: "decoder/block_000/layer_000/SelfAttention/v" - string_val: "decoder/block_000/layer_000/SelfAttention/o" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias" - string_val: "decoder/block_000/layer_001/rms_norm/scale" - string_val: "decoder/block_000/layer_001/EncDecAttention/q" - string_val: "decoder/block_000/layer_001/EncDecAttention/k" - string_val: "decoder/block_000/layer_001/EncDecAttention/v" - string_val: "decoder/block_000/layer_001/EncDecAttention/o" - string_val: "decoder/block_000/layer_002/rms_norm/scale" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/block_001/layer_000/rms_norm/scale" - string_val: "decoder/block_001/layer_000/SelfAttention/q" - string_val: "decoder/block_001/layer_000/SelfAttention/k" - string_val: "decoder/block_001/layer_000/SelfAttention/v" - string_val: "decoder/block_001/layer_000/SelfAttention/o" - string_val: "decoder/block_001/layer_001/rms_norm/scale" - string_val: "decoder/block_001/layer_001/EncDecAttention/q" - string_val: "decoder/block_001/layer_001/EncDecAttention/k" - string_val: "decoder/block_001/layer_001/EncDecAttention/v" - string_val: "decoder/block_001/layer_001/EncDecAttention/o" - string_val: "decoder/block_001/layer_002/rms_norm/scale" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel" - string_val: "decoder/rms_norm/scale" - string_val: "decoder/logits/kernel" - string_val: "shared/embedding_slot_v" - string_val: "encoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "encoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_000/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "encoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "encoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_0/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wi_1/kernel_slot_v" - string_val: "encoder/block_001/layer_001/DenseReluDense/wo/kernel_slot_v" - string_val: "encoder/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v" - string_val: "decoder/block_000/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_000/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_000/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_000/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/block_001/layer_000/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/q_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/k_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/v_slot_v" - string_val: "decoder/block_001/layer_000/SelfAttention/o_slot_v" - string_val: "decoder/block_001/layer_001/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/q_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/k_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/v_slot_v" - string_val: "decoder/block_001/layer_001/EncDecAttention/o_slot_v" - string_val: "decoder/block_001/layer_002/rms_norm/scale_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_0/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wi_1/kernel_slot_v" - string_val: "decoder/block_001/layer_002/DenseReluDense/wo/kernel_slot_v" - string_val: "decoder/rms_norm/scale_slot_v" - string_val: "decoder/logits/kernel_slot_v" - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Shape" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 105 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice/stack" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice/stack_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice/stack_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice" - op: "StridedSlice" - input: "report_uninitialized_variables_1/boolean_mask/Shape" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice/stack" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice/stack_1" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice/stack_2" - device: "/device:CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Prod/reduction_indices" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Prod" - op: "Prod" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice" - input: "report_uninitialized_variables_1/boolean_mask/Prod/reduction_indices" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Shape_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 105 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_1/stack" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_1/stack_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_1/stack_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_1" - op: "StridedSlice" - input: "report_uninitialized_variables_1/boolean_mask/Shape_1" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_1/stack" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_1/stack_1" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_1/stack_2" - device: "/device:CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Shape_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 105 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_2/stack" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_2/stack_1" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_2/stack_2" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/strided_slice_2" - op: "StridedSlice" - input: "report_uninitialized_variables_1/boolean_mask/Shape_2" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_2/stack" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_2/stack_1" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_2/stack_2" - device: "/device:CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 1 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/concat/values_1" - op: "Pack" - input: "report_uninitialized_variables_1/boolean_mask/Prod" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/concat/axis" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/concat" - op: "ConcatV2" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_1" - input: "report_uninitialized_variables_1/boolean_mask/concat/values_1" - input: "report_uninitialized_variables_1/boolean_mask/strided_slice_2" - input: "report_uninitialized_variables_1/boolean_mask/concat/axis" - device: "/device:CPU:0" - attr { - key: "N" - value { - i: 3 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Reshape" - op: "Reshape" - input: "report_uninitialized_variables_1/Const" - input: "report_uninitialized_variables_1/boolean_mask/concat" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Reshape_1/shape" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Reshape_1" - op: "Reshape" - input: "report_uninitialized_variables_1/LogicalNot" - input: "report_uninitialized_variables_1/boolean_mask/Reshape_1/shape" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 105 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Where" - op: "Where" - input: "report_uninitialized_variables_1/boolean_mask/Reshape_1" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/Squeeze" - op: "Squeeze" - input: "report_uninitialized_variables_1/boolean_mask/Where" - device: "/device:CPU:0" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "squeeze_dims" - value { - list { - i: 1 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/GatherV2/axis" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "report_uninitialized_variables_1/boolean_mask/GatherV2" - op: "GatherV2" - input: "report_uninitialized_variables_1/boolean_mask/Reshape" - input: "report_uninitialized_variables_1/boolean_mask/Squeeze" - input: "report_uninitialized_variables_1/boolean_mask/GatherV2/axis" - device: "/device:CPU:0" - attr { - key: "Taxis" - value { - type: DT_INT32 - } - } - attr { - key: "Tindices" - value { - type: DT_INT64 - } - } - attr { - key: "Tparams" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "batch_dims" - value { - i: 0 - } - } -} -node { - name: "report_uninitialized_resources_1/Const" - op: "Const" - device: "/device:CPU:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "concat_1/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat_1" - op: "ConcatV2" - input: "report_uninitialized_variables_1/boolean_mask/GatherV2" - input: "report_uninitialized_resources_1/Const" - input: "concat_1/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "init_2" - op: "NoOp" -} -node { - name: "init_all_tables" - op: "NoOp" -} -node { - name: "init_3" - op: "NoOp" -} -node { - name: "group_deps_3" - op: "NoOp" - input: "^init_2" - input: "^init_3" - input: "^init_all_tables" -} -node { - name: "Merge/MergeSummary" - op: "MergeSummary" - input: "learning_rate" - input: "loss" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -library { - function { - signature { - name: "__inference_Dataset_map_custom_pack_batch_206" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_INT32 - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_INT32 - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_INT32 - } - output_arg { - name: "identity_4" - type: DT_INT32 - } - output_arg { - name: "identity_5" - type: DT_INT32 - } - } - node_def { - name: "Cast" - op: "Cast" - input: "args_0" - attr { - key: "DstT" - value { - type: DT_INT64 - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Cast" - } - } - node_def { - name: "Cast_1" - op: "Cast" - input: "args_1" - attr { - key: "DstT" - value { - type: DT_INT64 - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Cast_1" - } - } - node_def { - name: "PackSequences2/inputs_max_length" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "PackSequences2/inputs_max_length" - } - } - node_def { - name: "PackSequences2/targets_max_length" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "PackSequences2/targets_max_length" - } - } - node_def { - name: "PackSequences2" - op: "PackSequences2" - input: "Cast:y:0" - input: "Cast_1:y:0" - input: "PackSequences2/inputs_max_length:output:0" - input: "PackSequences2/targets_max_length:output:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "PackSequences2" - } - } - node_def { - name: "Cast_2" - op: "Cast" - input: "PackSequences2:inputs_packed:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Cast_2" - } - } - node_def { - name: "Cast_3" - op: "Cast" - input: "PackSequences2:targets_packed:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Cast_3" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "Cast_2:y:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "PackSequences2:inputs_position:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "PackSequences2:inputs_segmentation:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "Cast_3:y:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - node_def { - name: "Identity_4" - op: "Identity" - input: "PackSequences2:targets_position:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_4" - } - } - node_def { - name: "Identity_5" - op: "Identity" - input: "PackSequences2:targets_segmentation:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_5" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - ret { - key: "identity_4" - value: "Identity_4:output:0" - } - ret { - key: "identity_5" - value: "Identity_5:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_normalize_222" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_INT32 - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_INT32 - } - input_arg { - name: "args_4" - type: DT_INT32 - } - input_arg { - name: "args_5" - type: DT_INT32 - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_INT32 - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_INT32 - } - output_arg { - name: "identity_4" - type: DT_INT32 - } - output_arg { - name: "identity_5" - type: DT_INT32 - } - } - node_def { - name: "Identity" - op: "Identity" - input: "args_0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "args_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "args_3" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - node_def { - name: "Identity_4" - op: "Identity" - input: "args_4" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_4" - } - } - node_def { - name: "Identity_5" - op: "Identity" - input: "args_5" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_5" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - ret { - key: "identity_4" - value: "Identity_4:output:0" - } - ret { - key: "identity_5" - value: "Identity_5:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - arg_attr { - key: 4 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_4" - } - } - } - } - arg_attr { - key: 5 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_5" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_my_fn_250" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_INT32 - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_INT32 - } - input_arg { - name: "args_4" - type: DT_INT32 - } - input_arg { - name: "args_5" - type: DT_INT32 - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_INT32 - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_INT32 - } - output_arg { - name: "identity_4" - type: DT_INT32 - } - output_arg { - name: "identity_5" - type: DT_INT32 - } - } - node_def { - name: "Reshape/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Reshape/shape" - } - } - node_def { - name: "Reshape" - op: "Reshape" - input: "args_0" - input: "Reshape/shape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Reshape" - } - } - node_def { - name: "Reshape_1/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Reshape_1/shape" - } - } - node_def { - name: "Reshape_1" - op: "Reshape" - input: "args_2" - input: "Reshape_1/shape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Reshape_1" - } - } - node_def { - name: "Reshape_2/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Reshape_2/shape" - } - } - node_def { - name: "Reshape_2" - op: "Reshape" - input: "args_1" - input: "Reshape_2/shape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Reshape_2" - } - } - node_def { - name: "Reshape_3/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Reshape_3/shape" - } - } - node_def { - name: "Reshape_3" - op: "Reshape" - input: "args_3" - input: "Reshape_3/shape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Reshape_3" - } - } - node_def { - name: "Reshape_4/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Reshape_4/shape" - } - } - node_def { - name: "Reshape_4" - op: "Reshape" - input: "args_5" - input: "Reshape_4/shape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Reshape_4" - } - } - node_def { - name: "Reshape_5/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Reshape_5/shape" - } - } - node_def { - name: "Reshape_5" - op: "Reshape" - input: "args_4" - input: "Reshape_5/shape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Reshape_5" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "Reshape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "Reshape_2:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "Reshape_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "Reshape_3:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - node_def { - name: "Identity_4" - op: "Identity" - input: "Reshape_5:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_4" - } - } - node_def { - name: "Identity_5" - op: "Identity" - input: "Reshape_4:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_5" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - ret { - key: "identity_4" - value: "Identity_4:output:0" - } - ret { - key: "identity_5" - value: "Identity_5:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - arg_attr { - key: 4 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_4" - } - } - } - } - arg_attr { - key: 5 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_5" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_lambda_340" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_INT32 - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_INT32 - } - input_arg { - name: "args_4" - type: DT_INT32 - } - input_arg { - name: "args_5" - type: DT_INT32 - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_INT32 - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_INT32 - } - output_arg { - name: "identity_4" - type: DT_INT32 - } - output_arg { - name: "identity_5" - type: DT_INT32 - } - } - node_def { - name: "strided_slice/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack" - } - } - node_def { - name: "strided_slice/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_1" - } - } - node_def { - name: "strided_slice/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_2" - } - } - node_def { - name: "strided_slice" - op: "StridedSlice" - input: "args_0" - input: "strided_slice/stack:output:0" - input: "strided_slice/stack_1:output:0" - input: "strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 511 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice" - } - } - node_def { - name: "strided_slice_1/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack" - } - } - node_def { - name: "strided_slice_1/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_1" - } - } - node_def { - name: "strided_slice_1/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_2" - } - } - node_def { - name: "strided_slice_1" - op: "StridedSlice" - input: "args_0" - input: "strided_slice_1/stack:output:0" - input: "strided_slice_1/stack_1:output:0" - input: "strided_slice_1/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 1 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice_1" - } - } - node_def { - name: "clip_by_value/Minimum/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value/Minimum/y" - } - } - node_def { - name: "clip_by_value/Minimum" - op: "Minimum" - input: "strided_slice_1:output:0" - input: "clip_by_value/Minimum/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value/Minimum" - } - } - node_def { - name: "clip_by_value/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value/y" - } - } - node_def { - name: "clip_by_value" - op: "Maximum" - input: "clip_by_value/Minimum:z:0" - input: "clip_by_value/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value" - } - } - node_def { - name: "concat/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "concat/axis" - } - } - node_def { - name: "concat" - op: "ConcatV2" - input: "strided_slice:output:0" - input: "clip_by_value:z:0" - input: "concat/axis:output:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "concat" - } - } - node_def { - name: "strided_slice_2/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_2/stack" - } - } - node_def { - name: "strided_slice_2/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_2/stack_1" - } - } - node_def { - name: "strided_slice_2/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_2/stack_2" - } - } - node_def { - name: "strided_slice_2" - op: "StridedSlice" - input: "args_3" - input: "strided_slice_2/stack:output:0" - input: "strided_slice_2/stack_1:output:0" - input: "strided_slice_2/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 511 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice_2" - } - } - node_def { - name: "strided_slice_3/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_3/stack" - } - } - node_def { - name: "strided_slice_3/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_3/stack_1" - } - } - node_def { - name: "strided_slice_3/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_3/stack_2" - } - } - node_def { - name: "strided_slice_3" - op: "StridedSlice" - input: "args_3" - input: "strided_slice_3/stack:output:0" - input: "strided_slice_3/stack_1:output:0" - input: "strided_slice_3/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 1 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice_3" - } - } - node_def { - name: "clip_by_value_1/Minimum/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value_1/Minimum/y" - } - } - node_def { - name: "clip_by_value_1/Minimum" - op: "Minimum" - input: "strided_slice_3:output:0" - input: "clip_by_value_1/Minimum/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value_1/Minimum" - } - } - node_def { - name: "clip_by_value_1/y" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value_1/y" - } - } - node_def { - name: "clip_by_value_1" - op: "Maximum" - input: "clip_by_value_1/Minimum:z:0" - input: "clip_by_value_1/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "clip_by_value_1" - } - } - node_def { - name: "concat_1/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "concat_1/axis" - } - } - node_def { - name: "concat_1" - op: "ConcatV2" - input: "strided_slice_2:output:0" - input: "clip_by_value_1:z:0" - input: "concat_1/axis:output:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "concat_1" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "concat:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "args_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "concat_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - node_def { - name: "Identity_4" - op: "Identity" - input: "args_4" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_4" - } - } - node_def { - name: "Identity_5" - op: "Identity" - input: "args_5" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_5" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - ret { - key: "identity_4" - value: "Identity_4:output:0" - } - ret { - key: "identity_5" - value: "Identity_5:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - arg_attr { - key: 4 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_4" - } - } - } - } - arg_attr { - key: 5 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_5" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map__filter_features_159" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_STRING - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_INT32 - } - } - node_def { - name: "Identity" - op: "Identity" - input: "args_0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_interleave_read_file_fn_88" - input_arg { - name: "args_0" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_VARIANT - } - is_stateful: true - control_output: "TensorSliceDataset" - } - node_def { - name: "flat_filenames/shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "flat_filenames/shape" - } - } - node_def { - name: "flat_filenames" - op: "Reshape" - input: "args_0" - input: "flat_filenames/shape:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "flat_filenames" - } - } - node_def { - name: "TensorSliceDataset" - op: "TensorSliceDataset" - input: "flat_filenames:output:0" - attr { - key: "Toutput_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "TensorSliceDataset" - } - } - node_def { - name: "FlatMapDataset" - op: "FlatMapDataset" - input: "TensorSliceDataset:handle:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_flat_map_read_one_file_38" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - experimental_debug_info { - original_node_names: "FlatMapDataset" - } - } - node_def { - name: "num_parallel_calls" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "num_parallel_calls" - } - } - node_def { - name: "ParallelMapDatasetV2" - op: "ParallelMapDatasetV2" - input: "FlatMapDataset:handle:0" - input: "num_parallel_calls:output:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_lambda_58" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT64 - type: DT_STRING - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } - experimental_debug_info { - original_node_names: "ParallelMapDatasetV2" - } - } - node_def { - name: "num_parallel_calls_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "num_parallel_calls_1" - } - } - node_def { - name: "ParallelMapDatasetV2_1" - op: "ParallelMapDatasetV2" - input: "ParallelMapDatasetV2:handle:0" - input: "num_parallel_calls_1:output:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map_lambda_72" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } - experimental_debug_info { - original_node_names: "ParallelMapDatasetV2_1" - } - } - node_def { - name: "num_parallel_calls_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } - experimental_debug_info { - original_node_names: "num_parallel_calls_2" - } - } - node_def { - name: "ParallelMapDatasetV2_2" - op: "ParallelMapDatasetV2" - input: "ParallelMapDatasetV2_1:handle:0" - input: "num_parallel_calls_2:output:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "deterministic" - value { - s: "default" - } - } - attr { - key: "f" - value { - func { - name: "__inference_Dataset_map__rename_84" - attr { - key: "_tf_data_function" - value { - b: true - } - } - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - type: DT_STRING - type: DT_INT32 - type: DT_STRING - } - } - } - attr { - key: "preserve_cardinality" - value { - b: true - } - } - attr { - key: "use_inter_op_parallelism" - value { - b: true - } - } - experimental_debug_info { - original_node_names: "ParallelMapDatasetV2_2" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "ParallelMapDatasetV2_2:handle:0" - input: "^TensorSliceDataset" - attr { - key: "T" - value { - type: DT_VARIANT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - control_ret { - key: "TensorSliceDataset" - value: "TensorSliceDataset" - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_lambda_296" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_INT32 - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_INT32 - } - input_arg { - name: "args_4" - type: DT_INT32 - } - input_arg { - name: "args_5" - type: DT_INT32 - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_INT32 - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_INT32 - } - output_arg { - name: "identity_4" - type: DT_INT32 - } - output_arg { - name: "identity_5" - type: DT_INT32 - } - } - node_def { - name: "strided_slice/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack" - } - } - node_def { - name: "strided_slice/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_1" - } - } - node_def { - name: "strided_slice/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_2" - } - } - node_def { - name: "strided_slice" - op: "StridedSlice" - input: "args_0" - input: "strided_slice/stack:output:0" - input: "strided_slice/stack_1:output:0" - input: "strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice" - } - } - node_def { - name: "Shape" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Shape" - } - } - node_def { - name: "strided_slice_1/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack" - } - } - node_def { - name: "strided_slice_1/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_1" - } - } - node_def { - name: "strided_slice_1/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_2" - } - } - node_def { - name: "strided_slice_1" - op: "StridedSlice" - input: "Shape:output:0" - input: "strided_slice_1/stack:output:0" - input: "strided_slice_1/stack_1:output:0" - input: "strided_slice_1/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - experimental_debug_info { - original_node_names: "strided_slice_1" - } - } - node_def { - name: "sub/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "sub/x" - } - } - node_def { - name: "sub" - op: "Sub" - input: "sub/x:output:0" - input: "strided_slice_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "sub" - } - } - node_def { - name: "Pad/paddings/0/0" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "Pad/paddings/0/0" - } - } - node_def { - name: "Pad/paddings/0" - op: "Pack" - input: "Pad/paddings/0/0:output:0" - input: "sub:z:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "Pad/paddings/0" - } - } - node_def { - name: "Pad/paddings" - op: "Pack" - input: "Pad/paddings/0:output:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "Pad/paddings" - } - } - node_def { - name: "Pad" - op: "Pad" - input: "strided_slice:output:0" - input: "Pad/paddings:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tpaddings" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Pad" - } - } - node_def { - name: "strided_slice_2/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_2/stack" - } - } - node_def { - name: "strided_slice_2/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_2/stack_1" - } - } - node_def { - name: "strided_slice_2/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_2/stack_2" - } - } - node_def { - name: "strided_slice_2" - op: "StridedSlice" - input: "args_3" - input: "strided_slice_2/stack:output:0" - input: "strided_slice_2/stack_1:output:0" - input: "strided_slice_2/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice_2" - } - } - node_def { - name: "Shape_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "Shape_1" - } - } - node_def { - name: "strided_slice_3/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_3/stack" - } - } - node_def { - name: "strided_slice_3/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_3/stack_1" - } - } - node_def { - name: "strided_slice_3/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_3/stack_2" - } - } - node_def { - name: "strided_slice_3" - op: "StridedSlice" - input: "Shape_1:output:0" - input: "strided_slice_3/stack:output:0" - input: "strided_slice_3/stack_1:output:0" - input: "strided_slice_3/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - experimental_debug_info { - original_node_names: "strided_slice_3" - } - } - node_def { - name: "sub_1/x" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "sub_1/x" - } - } - node_def { - name: "sub_1" - op: "Sub" - input: "sub_1/x:output:0" - input: "strided_slice_3:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "sub_1" - } - } - node_def { - name: "Pad_1/paddings/0/0" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "Pad_1/paddings/0/0" - } - } - node_def { - name: "Pad_1/paddings/0" - op: "Pack" - input: "Pad_1/paddings/0/0:output:0" - input: "sub_1:z:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "Pad_1/paddings/0" - } - } - node_def { - name: "Pad_1/paddings" - op: "Pack" - input: "Pad_1/paddings/0:output:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - dim { - size: 2 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "Pad_1/paddings" - } - } - node_def { - name: "Pad_1" - op: "Pad" - input: "strided_slice_2:output:0" - input: "Pad_1/paddings:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tpaddings" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Pad_1" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "Pad:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "args_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "Pad_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - node_def { - name: "Identity_4" - op: "Identity" - input: "args_4" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_4" - } - } - node_def { - name: "Identity_5" - op: "Identity" - input: "args_5" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_5" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - ret { - key: "identity_4" - value: "Identity_4:output:0" - } - ret { - key: "identity_5" - value: "Identity_5:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - arg_attr { - key: 4 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_4" - } - } - } - } - arg_attr { - key: 5 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 512 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_5" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_lambda_58" - input_arg { - name: "args_0" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_INT64 - } - output_arg { - name: "identity_1" - type: DT_STRING - } - output_arg { - name: "identity_2" - type: DT_INT64 - } - output_arg { - name: "identity_3" - type: DT_STRING - } - } - node_def { - name: "ParseSingleExample/ParseExample/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/Const" - } - } - node_def { - name: "ParseSingleExample/ParseExample/Const_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/Const_1" - } - } - node_def { - name: "ParseSingleExample/ParseExample/Const_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/Const_2" - } - } - node_def { - name: "ParseSingleExample/ParseExample/Const_3" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/Const_3" - } - } - node_def { - name: "ParseSingleExample/ParseExample/ParseExampleV2/names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/ParseExampleV2/names" - } - } - node_def { - name: "ParseSingleExample/ParseExample/ParseExampleV2/sparse_keys" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/ParseExampleV2/sparse_keys" - } - } - node_def { - name: "ParseSingleExample/ParseExample/ParseExampleV2/dense_keys" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 4 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 4 - } - } - string_val: "inputs" - string_val: "inputs_plaintext" - string_val: "targets" - string_val: "targets_plaintext" - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/ParseExampleV2/dense_keys" - } - } - node_def { - name: "ParseSingleExample/ParseExample/ParseExampleV2/ragged_keys" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/ParseExampleV2/ragged_keys" - } - } - node_def { - name: "ParseSingleExample/ParseExample/ParseExampleV2" - op: "ParseExampleV2" - input: "args_0" - input: "ParseSingleExample/ParseExample/ParseExampleV2/names:output:0" - input: "ParseSingleExample/ParseExample/ParseExampleV2/sparse_keys:output:0" - input: "ParseSingleExample/ParseExample/ParseExampleV2/dense_keys:output:0" - input: "ParseSingleExample/ParseExample/ParseExampleV2/ragged_keys:output:0" - input: "ParseSingleExample/ParseExample/Const:output:0" - input: "ParseSingleExample/ParseExample/Const_1:output:0" - input: "ParseSingleExample/ParseExample/Const_2:output:0" - input: "ParseSingleExample/ParseExample/Const_3:output:0" - attr { - key: "Tdense" - value { - list { - type: DT_INT64 - type: DT_STRING - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "dense_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "num_sparse" - value { - i: 0 - } - } - attr { - key: "ragged_split_types" - value { - list { - } - } - } - attr { - key: "ragged_value_types" - value { - list { - } - } - } - attr { - key: "sparse_types" - value { - list { - } - } - } - experimental_debug_info { - original_node_names: "ParseSingleExample/ParseExample/ParseExampleV2" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "ParseSingleExample/ParseExample/ParseExampleV2:dense_values:0" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "ParseSingleExample/ParseExample/ParseExampleV2:dense_values:1" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "ParseSingleExample/ParseExample/ParseExampleV2:dense_values:2" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "ParseSingleExample/ParseExample/ParseExampleV2:dense_values:3" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_lambda_72" - input_arg { - name: "args_0" - type: DT_INT64 - } - input_arg { - name: "args_1" - type: DT_STRING - } - input_arg { - name: "args_2" - type: DT_INT64 - } - input_arg { - name: "args_3" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_STRING - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_STRING - } - } - node_def { - name: "Cast" - op: "Cast" - input: "args_0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Cast" - } - } - node_def { - name: "Cast_1" - op: "Cast" - input: "args_2" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - attr { - key: "Truncate" - value { - b: false - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Cast_1" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "Cast:y:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_1" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "Cast_1:y:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "args_3" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_lambda_143" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_STRING - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_STRING - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_STRING - } - } - node_def { - name: "strided_slice/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack" - } - } - node_def { - name: "strided_slice/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_1" - } - } - node_def { - name: "strided_slice/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_2" - } - } - node_def { - name: "strided_slice" - op: "StridedSlice" - input: "args_0" - input: "strided_slice/stack:output:0" - input: "strided_slice/stack_1:output:0" - input: "strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice" - } - } - node_def { - name: "strided_slice_1/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack" - } - } - node_def { - name: "strided_slice_1/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_1" - } - } - node_def { - name: "strided_slice_1/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_2" - } - } - node_def { - name: "strided_slice_1" - op: "StridedSlice" - input: "args_2" - input: "strided_slice_1/stack:output:0" - input: "strided_slice_1/stack_1:output:0" - input: "strided_slice_1/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice_1" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "strided_slice:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_1" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "strided_slice_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "args_3" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_lambda_123" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_STRING - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_STRING - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_STRING - } - } - node_def { - name: "strided_slice/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack" - } - } - node_def { - name: "strided_slice/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 511 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_1" - } - } - node_def { - name: "strided_slice/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_2" - } - } - node_def { - name: "strided_slice" - op: "StridedSlice" - input: "args_0" - input: "strided_slice/stack:output:0" - input: "strided_slice/stack_1:output:0" - input: "strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice" - } - } - node_def { - name: "concat/values_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "concat/values_1" - } - } - node_def { - name: "concat/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "concat/axis" - } - } - node_def { - name: "concat" - op: "ConcatV2" - input: "strided_slice:output:0" - input: "concat/values_1:output:0" - input: "concat/axis:output:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "concat" - } - } - node_def { - name: "strided_slice_1/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack" - } - } - node_def { - name: "strided_slice_1/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 511 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_1" - } - } - node_def { - name: "strided_slice_1/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_2" - } - } - node_def { - name: "strided_slice_1" - op: "StridedSlice" - input: "args_2" - input: "strided_slice_1/stack:output:0" - input: "strided_slice_1/stack_1:output:0" - input: "strided_slice_1/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice_1" - } - } - node_def { - name: "concat_1/values_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "concat_1/values_1" - } - } - node_def { - name: "concat_1/axis" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "concat_1/axis" - } - } - node_def { - name: "concat_1" - op: "ConcatV2" - input: "strided_slice_1:output:0" - input: "concat_1/values_1:output:0" - input: "concat_1/axis:output:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "concat_1" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "concat:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_1" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "concat_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "args_3" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_flat_map_read_one_file_38" - input_arg { - name: "args_0" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_VARIANT - } - is_stateful: true - control_output: "TFRecordDataset" - } - node_def { - name: "compression_type" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "" - } - } - } - experimental_debug_info { - original_node_names: "compression_type" - } - } - node_def { - name: "buffer_size" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 262144 - } - } - } - experimental_debug_info { - original_node_names: "buffer_size" - } - } - node_def { - name: "TFRecordDataset" - op: "TFRecordDataset" - input: "args_0" - input: "compression_type:output:0" - input: "buffer_size:output:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "TFRecordDataset" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "TFRecordDataset:handle:0" - input: "^TFRecordDataset" - attr { - key: "T" - value { - type: DT_VARIANT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - control_ret { - key: "TFRecordDataset" - value: "TFRecordDataset" - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map__rename_84" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_STRING - } - input_arg { - name: "args_2" - type: DT_INT32 - } - input_arg { - name: "args_3" - type: DT_STRING - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_STRING - } - output_arg { - name: "identity_2" - type: DT_INT32 - } - output_arg { - name: "identity_3" - type: DT_STRING - } - } - node_def { - name: "Identity" - op: "Identity" - input: "args_0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "args_1" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - node_def { - name: "Identity_2" - op: "Identity" - input: "args_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_2" - } - } - node_def { - name: "Identity_3" - op: "Identity" - input: "args_3" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_3" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - ret { - key: "identity_2" - value: "Identity_2:output:0" - } - ret { - key: "identity_3" - value: "Identity_3:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - } - shape { - dim { - size: -1 - } - } - shape { - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - arg_attr { - key: 2 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_2" - } - } - } - } - arg_attr { - key: 3 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_3" - } - } - } - } - } - function { - signature { - name: "__inference_Dataset_map_lambda_175" - input_arg { - name: "args_0" - type: DT_INT32 - } - input_arg { - name: "args_1" - type: DT_INT32 - } - output_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity_1" - type: DT_INT32 - } - } - node_def { - name: "strided_slice/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack" - } - } - node_def { - name: "strided_slice/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_1" - } - } - node_def { - name: "strided_slice/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice/stack_2" - } - } - node_def { - name: "strided_slice" - op: "StridedSlice" - input: "args_0" - input: "strided_slice/stack:output:0" - input: "strided_slice/stack_1:output:0" - input: "strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice" - } - } - node_def { - name: "strided_slice_1/stack" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack" - } - } - node_def { - name: "strided_slice_1/stack_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 512 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_1" - } - } - node_def { - name: "strided_slice_1/stack_2" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - experimental_debug_info { - original_node_names: "strided_slice_1/stack_2" - } - } - node_def { - name: "strided_slice_1" - op: "StridedSlice" - input: "args_1" - input: "strided_slice_1/stack:output:0" - input: "strided_slice_1/stack_1:output:0" - input: "strided_slice_1/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "begin_mask" - value { - i: 1 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 0 - } - } - experimental_debug_info { - original_node_names: "strided_slice_1" - } - } - node_def { - name: "Identity" - op: "Identity" - input: "strided_slice:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity" - } - } - node_def { - name: "Identity_1" - op: "Identity" - input: "strided_slice_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - experimental_debug_info { - original_node_names: "Identity_1" - } - } - ret { - key: "identity" - value: "Identity:output:0" - } - ret { - key: "identity_1" - value: "Identity_1:output:0" - } - attr { - key: "_input_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_tf_data_function" - value { - b: true - } - } - arg_attr { - key: 0 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_0" - } - } - } - } - arg_attr { - key: 1 - value { - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "_user_specified_name" - value { - s: "args_1" - } - } - } - } - } -} -versions { - producer: 672 - min_consumer: 12 -} diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/model-info.txt b/t5x-main/t5x/testdata/mtf_tiny_t5/model-info.txt deleted file mode 100644 index 8adee65f8d52de2ea3c7be0d8280229eb3205765..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/mtf_tiny_t5/model-info.txt +++ /dev/null @@ -1,86 +0,0 @@ -Variable decoder/block_000/layer_000/SelfAttention/k size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_000/layer_000/SelfAttention/o size 4096 slice_size 4096 Shape[heads=128, d_model=32] -Variable decoder/block_000/layer_000/SelfAttention/q size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_000/layer_000/SelfAttention/relative_attention_bias size 64 slice_size 64 Shape[heads=2, buckets=32] -Variable decoder/block_000/layer_000/SelfAttention/v size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_000/layer_000/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable decoder/block_000/layer_001/EncDecAttention/k size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_000/layer_001/EncDecAttention/o size 4096 slice_size 4096 Shape[heads=128, d_model=32] -Variable decoder/block_000/layer_001/EncDecAttention/q size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_000/layer_001/EncDecAttention/v size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_000/layer_001/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable decoder/block_000/layer_002/DenseReluDense/wi_0/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable decoder/block_000/layer_002/DenseReluDense/wi_1/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable decoder/block_000/layer_002/DenseReluDense/wo/kernel size 2048 slice_size 2048 Shape[d_ff=64, d_model=32] -Variable decoder/block_000/layer_002/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable decoder/block_001/layer_000/SelfAttention/k size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_001/layer_000/SelfAttention/o size 4096 slice_size 4096 Shape[heads=128, d_model=32] -Variable decoder/block_001/layer_000/SelfAttention/q size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_001/layer_000/SelfAttention/v size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_001/layer_000/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable decoder/block_001/layer_001/EncDecAttention/k size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_001/layer_001/EncDecAttention/o size 4096 slice_size 4096 Shape[heads=128, d_model=32] -Variable decoder/block_001/layer_001/EncDecAttention/q size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_001/layer_001/EncDecAttention/v size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable decoder/block_001/layer_001/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable decoder/block_001/layer_002/DenseReluDense/wi_0/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable decoder/block_001/layer_002/DenseReluDense/wi_1/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable decoder/block_001/layer_002/DenseReluDense/wo/kernel size 2048 slice_size 2048 Shape[d_ff=64, d_model=32] -Variable decoder/block_001/layer_002/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable decoder/logits/kernel size 1028096 slice_size 1028096 Shape[d_model=32, vocab=32128] -Variable decoder/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable encoder/block_000/layer_000/SelfAttention/k size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable encoder/block_000/layer_000/SelfAttention/o size 4096 slice_size 4096 Shape[heads=128, d_model=32] -Variable encoder/block_000/layer_000/SelfAttention/q size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable encoder/block_000/layer_000/SelfAttention/relative_attention_bias size 64 slice_size 64 Shape[heads=2, buckets=32] -Variable encoder/block_000/layer_000/SelfAttention/v size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable encoder/block_000/layer_000/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable encoder/block_000/layer_001/DenseReluDense/wi_0/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable encoder/block_000/layer_001/DenseReluDense/wi_1/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable encoder/block_000/layer_001/DenseReluDense/wo/kernel size 2048 slice_size 2048 Shape[d_ff=64, d_model=32] -Variable encoder/block_000/layer_001/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable encoder/block_001/layer_000/SelfAttention/k size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable encoder/block_001/layer_000/SelfAttention/o size 4096 slice_size 4096 Shape[heads=128, d_model=32] -Variable encoder/block_001/layer_000/SelfAttention/q size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable encoder/block_001/layer_000/SelfAttention/v size 4096 slice_size 4096 Shape[d_model=32, heads=128] -Variable encoder/block_001/layer_000/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable encoder/block_001/layer_001/DenseReluDense/wi_0/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable encoder/block_001/layer_001/DenseReluDense/wi_1/kernel size 2048 slice_size 2048 Shape[d_model=32, d_ff=64] -Variable encoder/block_001/layer_001/DenseReluDense/wo/kernel size 2048 slice_size 2048 Shape[d_ff=64, d_model=32] -Variable encoder/block_001/layer_001/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable encoder/rms_norm/scale size 32 slice_size 32 Shape[d_model=32] -Variable shared/embedding size 1028096 slice_size 1028096 Shape[vocab=32128, d_model=32] -Trainable Variables count: 52 Total size: 2179584 Total slice_size: 2179584 -All Variables count: 104 Total size: 4359168 Total slice_size: 4359168 -Counters: -einsum: 4.99e+10 -einsum_unique: 4.99e+10 -output: 3.07e+07 - output/AddOperation: 2.18e+06 - output/Constant: 1 - output/EinsumOperation: 6.54e+06 - output/ImportOperation: 9.84e+04 - output/MinMaxOperation: 104 - output/ReduceOperation: 104 - output/ReshapeOperation: 9.83e+04 - output/ScalarAddOperation: 2.18e+06 - output/ScalarMultiplyOperation: 4.36e+06 - output/SlicewiseOperation: 8.72e+06 - output/Variable: 4.36e+06 - output/WhileLoopOperation: 2.18e+06 -output_unique: 3.07e+07 - output_unique/AddOperation: 2.18e+06 - output_unique/Constant: 1 - output_unique/EinsumOperation: 6.54e+06 - output_unique/ImportOperation: 9.84e+04 - output_unique/MinMaxOperation: 104 - output_unique/ReduceOperation: 104 - output_unique/ReshapeOperation: 9.83e+04 - output_unique/ScalarAddOperation: 2.18e+06 - output_unique/ScalarMultiplyOperation: 4.36e+06 - output_unique/SlicewiseOperation: 8.72e+06 - output_unique/Variable: 4.36e+06 - output_unique/WhileLoopOperation: 2.18e+06 -variables: 4.36e+06 - variables/trainable: 2.18e+06 - variables/untrainable: 2.18e+06 diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00000-of-00002 b/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00000-of-00002 deleted file mode 100644 index 1b1cb4d44c57c2d7a5122870fa6ac3e62ff7e94e..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00000-of-00002 and /dev/null differ diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002 b/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002 deleted file mode 100644 index 95b329719f21f490175a1b20d02a737dd9aa6d40..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d2ded324aec12e818cf3a29046fda0e3d85915621d7415bd7c8b80cad03588c9 -size 13077504 diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.index b/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.index deleted file mode 100644 index c7b16afaf57683aa07ece791780f18891a87ec21..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.index and /dev/null differ diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.meta b/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.meta deleted file mode 100644 index a04db5a0de37e3a050181801a63785d2985b9a85..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/mtf_tiny_t5/model.ckpt-0.meta +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d77a2aaee5b3076bb15729199e4d35ad6b85e25375ad849cd8bacc91ed75da7a -size 2983432 diff --git a/t5x-main/t5x/testdata/mtf_tiny_t5/operative_config.gin b/t5x-main/t5x/testdata/mtf_tiny_t5/operative_config.gin deleted file mode 100644 index 5c1a84fd7db84eed6a142f0d55b47f88c2540c91..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/mtf_tiny_t5/operative_config.gin +++ /dev/null @@ -1,312 +0,0 @@ -import mesh_tensorflow.optimize -import mesh_tensorflow.transformer.dataset -import mesh_tensorflow.transformer.learning_rate_schedules -import mesh_tensorflow.transformer.t2t_vocabulary -import mesh_tensorflow.transformer.transformer -import mesh_tensorflow.transformer.transformer_layers -import mesh_tensorflow.transformer.utils -import t5.models.mesh_transformer - -# Macros: -# ============================================================================== -d_ff = 64 -d_kv = 64 -d_model = 32 -dropout_rate = 0.0 -MIXTURE_NAME = 'c4_v020_unsupervised' -num_heads = 2 -num_layers = 2 - -# Parameters for adafactor_decay_rate_pow: -# ============================================================================== -adafactor_decay_rate_pow.offset = 0 - -# Parameters for AdafactorOptimizer: -# ============================================================================== -AdafactorOptimizer.beta1 = 0.0 -AdafactorOptimizer.clipping_threshold = 1.0 -AdafactorOptimizer.decay_rate = None -AdafactorOptimizer.epsilon1 = 1e-30 -AdafactorOptimizer.epsilon2 = 0.001 -AdafactorOptimizer.factored = True -AdafactorOptimizer.min_dim_size_to_factor = 128 -AdafactorOptimizer.multiply_by_parameter_scale = True - -# Parameters for Bitransformer: -# ============================================================================== -Bitransformer.shared_embedding = True - -# Parameters for learning_rate_schedules.constant: -# ============================================================================== -learning_rate_schedules.constant.value = 1.0 - -# Parameters for decoder/DenseReluDense: -# ============================================================================== -decoder/DenseReluDense.activation = ['gelu', 'linear'] -decoder/DenseReluDense.dropout_rate = %dropout_rate -decoder/DenseReluDense.hidden_size = %d_ff -decoder/DenseReluDense.use_bias = False - -# Parameters for encoder/DenseReluDense: -# ============================================================================== -encoder/DenseReluDense.activation = ['gelu', 'linear'] -encoder/DenseReluDense.dropout_rate = %dropout_rate -encoder/DenseReluDense.hidden_size = %d_ff -encoder/DenseReluDense.use_bias = False - -# Parameters for enc_dec_attention: -# ============================================================================== -# None. - -# Parameters for enc_dec_attention_bias: -# ============================================================================== -# None. - -# Parameters for decoder/EncDecAttention: -# ============================================================================== -decoder/EncDecAttention.relative_attention_type = None - -# Parameters for get_variable_dtype: -# ============================================================================== -get_variable_dtype.activation_dtype = 'float32' -get_variable_dtype.slice_dtype = 'float32' -get_variable_dtype.master_dtype = 'bfloat16' - -# Parameters for get_vocab_embedding_cls: -# ============================================================================== -# None. - -# Parameters for get_vocabulary: -# ============================================================================== -get_vocabulary.mixture_or_task_name = %MIXTURE_NAME - -# Parameters for decoder/LayerStack: -# ============================================================================== -decoder/LayerStack.dropout_rate = None -decoder/LayerStack.norm_epsilon = None -decoder/LayerStack.recompute_grads = False -decoder/LayerStack.sublayers_final = \ - [@transformer.sublayer_rms_norm, @transformer.sublayer_dropout] -decoder/LayerStack.sublayers_initial = [@transformer.sublayer_dropout] -decoder/LayerStack.sublayers_per_layer = \ - [@transformer.sublayer_rms_norm, - @transformer.sublayer_call_layer, - @transformer.sublayer_dropout, - @transformer.sublayer_residual] - -# Parameters for encoder/LayerStack: -# ============================================================================== -encoder/LayerStack.dropout_rate = None -encoder/LayerStack.norm_epsilon = None -encoder/LayerStack.recompute_grads = False -encoder/LayerStack.sublayers_final = \ - [@transformer.sublayer_rms_norm, @transformer.sublayer_dropout] -encoder/LayerStack.sublayers_initial = [@transformer.sublayer_dropout] -encoder/LayerStack.sublayers_per_layer = \ - [@transformer.sublayer_rms_norm, - @transformer.sublayer_call_layer, - @transformer.sublayer_dropout, - @transformer.sublayer_residual] - -# Parameters for linear_decay: -# ============================================================================== -linear_decay.steps_or_fraction = 0.1 - -# Parameters for make_bitransformer: -# ============================================================================== -make_bitransformer.decoder_name = 'decoder' -make_bitransformer.encoder_name = 'encoder' - -# Parameters for decoder/make_layer_stack: -# ============================================================================== -decoder/make_layer_stack.block_scope = True -decoder/make_layer_stack.layers = \ - [@mesh_tensorflow.transformer.transformer_layers.SelfAttention, - @mesh_tensorflow.transformer.transformer_layers.EncDecAttention, - @mesh_tensorflow.transformer.transformer_layers.DenseReluDense] -decoder/make_layer_stack.num_layers = %num_layers - -# Parameters for encoder/make_layer_stack: -# ============================================================================== -encoder/make_layer_stack.block_scope = True -encoder/make_layer_stack.layers = \ - [@mesh_tensorflow.transformer.transformer_layers.SelfAttention, - @mesh_tensorflow.transformer.transformer_layers.DenseReluDense] -encoder/make_layer_stack.num_layers = %num_layers - -# Parameters for mesh_train_dataset_fn: -# ============================================================================== -mesh_train_dataset_fn.mixture_or_task_name = %MIXTURE_NAME -mesh_train_dataset_fn.pack = True -mesh_train_dataset_fn.seed = None -mesh_train_dataset_fn.use_cached = True - -# Parameters for pack_dataset: -# ============================================================================== -pack_dataset.use_custom_ops = True - -# Parameters for pack_or_pad: -# ============================================================================== -# None. - -# Parameters for learning_rate_schedules.product_learning_rate: -# ============================================================================== -learning_rate_schedules.product_learning_rate.offset = 0 - -# Parameters for rewrite_stack_variables: -# ============================================================================== -rewrite_stack_variables.max_combined_variable_size = 536870912 - -# Parameters for run: -# ============================================================================== -run.autostack = True -run.batch_size = ('tokens_per_batch', 16384) -run.dataset_split = 'train' -run.ensemble_inputs = None -run.eval_checkpoint_step = None -run.eval_dataset_fn = None -run.eval_summary_dir = None -run.export_checkpoint_step = None -run.export_path = '' -run.init_checkpoint = None -run.iterations_per_loop = 100 -run.keep_checkpoint_max = None -run.layout_rules = \ - 'ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch' -run.learning_rate_schedule = \ - [@learning_rate_schedules.truncated_rsqrt, - @learning_rate_schedules.linear_decay, - @learning_rate_schedules.constant] -run.mesh_devices = None -run.mesh_shape = [] -run.mode = 'train' -run.model_type = 'bitransformer' -run.optimizer = @optimize.AdafactorOptimizer -run.output_eval_examples = True -run.perplexity_eval_steps = 100 -run.predict_fn = None -run.save_checkpoints_steps = 5000 -run.sequence_length = {'inputs': 512, 'targets': 512} -run.skip_seen_data = False -run.total_run_steps = None -run.train_dataset_fn = @t5.models.mesh_transformer.mesh_train_dataset_fn -run.train_steps = 100 -run.variable_filter = None - -# Parameters for decoder/SelfAttention: -# ============================================================================== -decoder/SelfAttention.attention_func = None -decoder/SelfAttention.attention_kwargs = None -decoder/SelfAttention.combine_dims = True -decoder/SelfAttention.dropout_rate = %dropout_rate -decoder/SelfAttention.fold_scaling_into_initializer = True -decoder/SelfAttention.keep_query_heads_dims = False -decoder/SelfAttention.key_value_size = %d_kv -decoder/SelfAttention.num_heads = %num_heads -decoder/SelfAttention.num_memory_heads = 0 -decoder/SelfAttention.relative_attention_num_buckets = 32 -decoder/SelfAttention.relative_attention_type = 'bias_shared' -decoder/SelfAttention.shared_kv = False - -# Parameters for encoder/SelfAttention: -# ============================================================================== -encoder/SelfAttention.attention_func = None -encoder/SelfAttention.attention_kwargs = None -encoder/SelfAttention.combine_dims = True -encoder/SelfAttention.dropout_rate = %dropout_rate -encoder/SelfAttention.fold_scaling_into_initializer = True -encoder/SelfAttention.keep_query_heads_dims = False -encoder/SelfAttention.key_value_size = %d_kv -encoder/SelfAttention.num_heads = %num_heads -encoder/SelfAttention.num_memory_heads = 0 -encoder/SelfAttention.relative_attention_num_buckets = 32 -encoder/SelfAttention.relative_attention_type = 'bias_shared' -encoder/SelfAttention.shared_kv = False - -# Parameters for serialize_num_microbatches: -# ============================================================================== -serialize_num_microbatches.tokens_per_microbatch_per_replica = 4096 - -# Parameters for sublayer_call_layer: -# ============================================================================== -# None. - -# Parameters for sublayer_dropout: -# ============================================================================== -sublayer_dropout.dropout_rate = %dropout_rate - -# Parameters for sublayer_mask_padding: -# ============================================================================== -# None. - -# Parameters for sublayer_residual: -# ============================================================================== -# None. - -# Parameters for sublayer_rms_norm: -# ============================================================================== -sublayer_rms_norm.epsilon = 1e-06 -sublayer_rms_norm.name = 'rms_norm' - -# Parameters for tpu_estimator_model_fn: -# ============================================================================== -tpu_estimator_model_fn.hierarchical_tiling_spec = None -tpu_estimator_model_fn.init_variable_filter = '' -tpu_estimator_model_fn.outer_batch_size = 1 -tpu_estimator_model_fn.tpu_summaries = False - -# Parameters for truncated_rsqrt: -# ============================================================================== -truncated_rsqrt.warmup_steps = 10000.0 - -# Parameters for unit_scaling_convention: -# ============================================================================== -unit_scaling_convention.value = False - -# Parameters for decoder/Unitransformer: -# ============================================================================== -decoder/Unitransformer.d_model = %d_model -decoder/Unitransformer.ensemble = None -decoder/Unitransformer.input_full_attention = False -decoder/Unitransformer.label_smoothing = 0.0 -decoder/Unitransformer.loss_denominator = None -decoder/Unitransformer.loss_fn = None -decoder/Unitransformer.loss_on_targets_only = False -decoder/Unitransformer.max_length = 512 -decoder/Unitransformer.positional_embedding = False -decoder/Unitransformer.shared_embedding_and_softmax_weights = False -decoder/Unitransformer.sinusoid_positional_embedding = False -decoder/Unitransformer.token_dropout_rate = 0.0 -decoder/Unitransformer.vocab_divisor = 128 -decoder/Unitransformer.z_loss = 0.0001 - -# Parameters for encoder/Unitransformer: -# ============================================================================== -encoder/Unitransformer.d_model = %d_model -encoder/Unitransformer.ensemble = None -encoder/Unitransformer.input_full_attention = False -encoder/Unitransformer.label_smoothing = 0.0 -encoder/Unitransformer.loss_denominator = None -encoder/Unitransformer.loss_fn = None -encoder/Unitransformer.loss_on_targets_only = False -encoder/Unitransformer.max_length = 512 -encoder/Unitransformer.positional_embedding = False -encoder/Unitransformer.shared_embedding_and_softmax_weights = False -encoder/Unitransformer.sinusoid_positional_embedding = False -encoder/Unitransformer.token_dropout_rate = 0.0 -encoder/Unitransformer.vocab_divisor = 128 -encoder/Unitransformer.z_loss = 0.0001 - -# Parameters for unsupervised: -# ============================================================================== -unsupervised.preprocessors = None - -# Parameters for VarianceScalingInitializer: -# ============================================================================== -VarianceScalingInitializer.distribution = 'normal' -VarianceScalingInitializer.mode = 'fan_in' -VarianceScalingInitializer.scale = 1.0 - -# Parameters for VocabEmbedding: -# ============================================================================== -VocabEmbedding.scale_variable_like_classifier_weights = False diff --git a/t5x-main/t5x/testdata/pinned_ckpt_dir/PINNED b/t5x-main/t5x/testdata/pinned_ckpt_dir/PINNED deleted file mode 100644 index 56a6051ca2b02b04ef92d5150c9ef600403cb1de..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/pinned_ckpt_dir/PINNED +++ /dev/null @@ -1 +0,0 @@ -1 \ No newline at end of file diff --git a/t5x-main/t5x/testdata/test_t5_tiny.checkpoint_0 b/t5x-main/t5x/testdata/test_t5_tiny.checkpoint_0 deleted file mode 100644 index 057d4daa14559d01591049d667c4943c0a6a40b4..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/test_t5_tiny.checkpoint_0 +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2cd4d2f4a0466ba7d302127f1cde54628c0f9413999040b719f5178363a3eca9 -size 4154134 diff --git a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 b/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 deleted file mode 100644 index a456e60f98c70763c791e0171116001df557d368..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.bias/0 and /dev/null differ diff --git a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 b/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 deleted file mode 100644 index 207b68275d401f7e5180d0c9a34f183769ecc8e8..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.kernel/0.0 and /dev/null differ diff --git a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 b/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 deleted file mode 100644 index 437babcba202ff9cfb2ae79f844c2cf025332a0a..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.param_states.step/0 and /dev/null differ diff --git a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 b/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 deleted file mode 100644 index 437babcba202ff9cfb2ae79f844c2cf025332a0a..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_orbax/1/_optimizer.state.step/0 and /dev/null differ diff --git a/t5x-main/t5x/testdata/tiny_orbax/1/checkpoint b/t5x-main/t5x/testdata/tiny_orbax/1/checkpoint deleted file mode 100644 index 479f5c61b01e478437d6e4512c8398c9f6523855..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/testdata/tiny_orbax/1/checkpoint +++ /dev/null @@ -1 +0,0 @@ -_optimizerstateparam_statesbias0PLACEHOLDER://_optimizer.state.param_states.biaskernel2PLACEHOLDER://_optimizer.state.param_states.kernelstep0PLACEHOLDER://_optimizer.state.param_states.stepstep#PLACEHOLDER://_optimizer.state.steptargetflax_mutablesflax_mutables_axesparams_axes \ No newline at end of file diff --git a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/checkpoint b/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/checkpoint deleted file mode 100644 index 5ddf7bb965f98df3ef31af740a397d4be1b06a6e..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/checkpoint and /dev/null differ diff --git a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 b/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 deleted file mode 100644 index 9f6100a0c356d11fb39b17061ffb8ee19a5184d3..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.bias/0 and /dev/null differ diff --git a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 b/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 deleted file mode 100644 index c302835877f986d4b30ed1927128b93438a0070d..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.kernel/0.0 and /dev/null differ diff --git a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 b/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 deleted file mode 100644 index 872eef28e9039ee55aff5824225abaf59c64a9e2..0000000000000000000000000000000000000000 Binary files a/t5x-main/t5x/testdata/tiny_t5/checkpoint_1/state.param_states.step/0 and /dev/null differ diff --git a/t5x-main/t5x/train.py b/t5x-main/t5x/train.py deleted file mode 100644 index baa989aad80a6d4902a3c70d7bf7df1044bf78fe..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/train.py +++ /dev/null @@ -1,965 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Script to pretrain or finetune in JAX using a SeqIO pipeline. - -""" - - -# pylint: disable=g-import-not-at-top - -import functools -import gc -import math -import os -import time -from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple, Type - -# Set Linen to add profiling information when constructing Modules. -# Must be set before flax imports. -os.environ['FLAX_PROFILE'] = 'true' -from absl import logging -from clu import metric_writers -import jax -from jax import random -from jax.experimental import multihost_utils -import jax.numpy as jnp -import numpy as np -import seqio -from t5x import checkpoints -from t5x import eval as eval_lib -from t5x import models -from t5x import partitioning -from t5x import train_state as train_state_lib -from t5x import trainer as trainer_lib -from t5x import utils -import tensorflow as tf -# pylint:enable=g-import-not-at-top - -# pylint:enable=g-import-not-at-top - -# Automatically search for gin files relative to the T5X package. -_DEFAULT_GIN_SEARCH_PATHS = [ - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -] -P = partitioning.PartitionSpec -# Special key that used to distinguish train metrics. -TRAIN_METRIC_KEY = 'train' -# String keys that is acceptable from config. -_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys()) -_IMPORT_TIME = time.time() - - -def run_actions( - mode: trainer_lib.ActionMode, - actions: trainer_lib.ActionMapType, - train_state: train_state_lib.TrainState, - metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType], -) -> bool: - """Invokes all actions on the given mode on host 0, then broadcasts to all. - - Args: - mode: The mode to run the actions. e.g., if mode is `train`, only actions - configured to run with `train` mode will be invoked. - actions: A mapping of actions that runs after train, eval or infer_eval, to - inspect the model and perform useful operations, e.g., early stopping. - train_state: The current train_state of the trainer. - metrics_by_task: A map of metrics keyed by task name. - - Returns: - A bool indicating whether training should be halted. - - Raises: - RuntimeError: When the metrics processed on host 0 is None. - """ - stop_training = False - if jax.process_index() == 0: - if not metrics_by_task: - raise RuntimeError('Metric is unexpectedly empty on process 0') - for action in actions.get(mode, []): - stop_training |= action.run(train_state, metrics_by_task=metrics_by_task) - # Broadcast result from host 0 to others. - return bool(multihost_utils.broadcast_one_to_all(jnp.array(stop_training))) - - -def train( - *, - model: models.BaseModel, - train_dataset_cfg: utils.DatasetConfig, - train_eval_dataset_cfg: Optional[utils.DatasetConfig], - infer_eval_dataset_cfg: Optional[utils.DatasetConfig], - checkpoint_cfg: utils.CheckpointConfig, - partitioner: partitioning.BasePartitioner, - trainer_cls: trainer_lib.BaseTrainerConstructor, - model_dir: str, - total_steps: int, - eval_steps: int, - eval_period: int, - relative_steps: Optional[int] = None, - stats_period: Optional[int] = None, - random_seed: Optional[int], - use_hardware_rng: bool = False, - summarize_config_fn: Callable[ - [str, metric_writers.MetricWriter, int], None - ], - inference_evaluator_cls: utils.EvaluatorConstructor = seqio.Evaluator, - get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset, - concurrent_metrics: bool = True, - actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None, - train_eval_get_dataset_fn: utils.GetEvalDatasetCallable = utils.get_training_eval_datasets, - run_eval_before_training: bool = False, - train_state_initializer_cls: Type[ - utils.TrainStateInitializer - ] = utils.TrainStateInitializer, - use_orbax: bool = True, - verify_matching_vocabs_fn: Optional[ - Callable[[utils.DatasetConfig, models.BaseTransformerModel], None] - ] = utils.verify_matching_vocabs, - gc_period: int = 0, -) -> Tuple[int, train_state_lib.TrainState]: - """Train function. - - Args: - model: The model object to use for training. - train_dataset_cfg: Specification for the dataset to train with. - train_eval_dataset_cfg: Specification for the dataset to evaluate with using - the train metrics and no inference (e.g., uses teacher forcing). If None, - train eval is disabled. - infer_eval_dataset_cfg: Specification for the dataset to evaluate with using - the inference metrics (e.g., uses sampled decoding). If None, inference - eval is disabled. - checkpoint_cfg: Specification for saving and restoring model parameters and - dataset state to/from checkpoints. - partitioner: Partitioner for model parameters and data across devices. - trainer_cls: An implementation of BaseTrainer. - model_dir: Path of directory to store checkpoints and metric summaries. - total_steps: The step number to stop training after. The number of actual - steps trained in this run will be this number minus the starting step from - the checkpoint. If this is set to the starting step from the checkpoint, - the model will not be compiled for training and training will not be run. - This can be used in conjunction with `run_eval_before_training` to only - evaluate a model. - eval_steps: The number of batches to process for each train-eval loop. - eval_period: The number of train steps between each evaluation (both - train-eval and infer-eval). - relative_steps: The number of train steps to take relative to the current - step loaded from the checkpoint. If this is set, total_steps is ignored. - stats_period: The number of train steps between writing scalar stats. If - None, defaults to eval_period. - random_seed: A random seed to use for dropout and initialization. If None, a - fast, non-deterministic hardware-based RNG is used. - use_hardware_rng: Whether to force using the RngBitGenerator based hardware - rng, which takes seeds and acts similarly to software PRNG in that it - should be seed-deterministic. The new RngBitGenerator custom PRNG system - should be reproducible for a given sharding, but the numbers will change - for different shardings of the same model. - summarize_config_fn: A function that takes in the model directory, a - SummaryWriter, and the step number, and writes a summary of the - inference_evaluator_cls: seqio.Evaluator class to use for inference - evaluation, potentially with bound configuration args. - get_dataset_fn: The callable use to get the train and train-eval datasets - based on the DatasetConfig and shard information. - concurrent_metrics: If True, allow metrics computation and logging to - overlap with training. Will likely result in additional TPU memory usage. - actions: A mapping of actions that runs after train, eval or infer_eval, to - inspect the model and perform useful operations, e.g., early stopping. The - key must have a 1:1 mapping to ActionMode enum. For EVAL actions to - actually work, this requires `concurrent_metrics` to be turned off, since - chaining futures and mutating states concurrently might be error-prone. - train_eval_get_dataset_fn: Optional callable use to get the train-eval - datasets based on the DatasetConfig and shard information. If missing, it - defaults to `utils.get_training_eval_datasets`. - run_eval_before_training: If True, calculate training eval and inference - eval metrics before training begins. - train_state_initializer_cls: t5x.utils.TrainStateInitializer class for - initializing partitioned TrainState from checkpoints or scratch. - use_orbax: if True, uses Orbax for checkpointing. Experimental feature. - verify_matching_vocabs_fn: Function to validate whether the task vocabulary - matches the model vocabulary, if the model is a BaseTransformerModel - instance. Should raise an exception on error. - gc_period: The number of train steps between runs of the garbage collector. - If 0, the garbage collector will run at the normal frequency. # BEGIN - training_eval_average_metrics: Averages the metric over the list of tasks - for training_eval (e.g., {'task_average/accuracy': ['task_a', 'task_b']}). - infer_eval_average_metrics: Averages the metric over the list of tasks for - infer_eval (e.g., {'task_average/accuracy': ['task_a', 'task_b']}). # END - - Returns: - The tuple of (last_step, last_train_state). - """ - logging.info('Process ID: %d', jax.process_index()) - tf.io.gfile.makedirs(model_dir) - - if use_orbax: - logging.info('Checkpointing with Orbax enabled.') - - # Each "epoch" of the training loop should be the min of the eval period, - # checkpoint period or the full training. - # We compute here to ensure that the eval period and checkpoint period are - # divisible by this number, otherwise we fail. - eval_enabled = train_eval_dataset_cfg or infer_eval_dataset_cfg - eval_period = eval_period if eval_enabled else 0 - checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0 - checkpoint_steps = ( - checkpoint_cfg.save.checkpoint_steps if checkpoint_cfg.save else [] - ) - - if use_hardware_rng or random_seed is None: - logging.info( - 'Using fast RngBitGenerator PRNG for initialization and dropout.' - ) - - if random_seed is None: - random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time())) - logging.info('Random seed not provided, using RNG seed %s', random_seed) - else: - logging.warning( - 'When using hardware RNG with a fixed seed, repeatability is only ' - 'guaranteed for fixed hardware and partitioning schemes and for a ' - 'fixed version of this code and its dependencies.' - ) - utils.set_hardware_rng_ops() - rng = random.PRNGKey(random_seed) - else: - logging.info( - 'Using seed for initialization and dropout RNG: %d', random_seed - ) - rng = random.PRNGKey(random_seed) - - init_rng, trainer_rng = random.split(rng, 2) - - # --------------------------------------------------------------------------- - # Initialize datasets - # --------------------------------------------------------------------------- - - if train_dataset_cfg.seed and not ( - checkpoint_cfg.save and checkpoint_cfg.save.save_dataset - ): - logging.warning( - 'Providing a random seed for the train dataset with ' - '`checkpoint_train_ds=False` is dangerous since each ' - 'preemption/restart will cause the dataset to deterministically replay ' - 'from the beginning.' - ) - - data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size) - ds_shard_id = data_layout.shard_id - num_ds_shards = data_layout.num_shards - - def _verify_matching_vocabs(cfg: utils.DatasetConfig): - if verify_matching_vocabs_fn and isinstance( - model, models.BaseTransformerModel - ): - verify_matching_vocabs_fn(cfg, model) - - _verify_matching_vocabs(train_dataset_cfg) - - train_iter = get_dataset_fn( - train_dataset_cfg, ds_shard_id, num_ds_shards, model.FEATURE_CONVERTER_CLS - ) - train_iter = utils.prepare_train_iter( - train_iter, - checkpoint_cfg=checkpoint_cfg, - partitioner=partitioner, - data_layout=data_layout, - ) - input_shapes = jax.tree.map( - lambda x: (data_layout.batch_size, *x.shape[1:]), - train_iter.element_spec, - ) - input_types = jax.tree.map(lambda x: x.dtype, train_iter.element_spec) - - if train_eval_dataset_cfg: - _verify_matching_vocabs(train_eval_dataset_cfg) - train_eval_datasets = train_eval_get_dataset_fn( - train_eval_dataset_cfg, - ds_shard_id, - num_ds_shards, - eval_steps, - model.FEATURE_CONVERTER_CLS, - ) # type: Mapping[str, tf.data.Dataset] - if not train_eval_datasets: - logging.warning( - 'No train_eval datasets loaded from config `train_eval_dataset_cfg`: ' - '%s', - train_eval_dataset_cfg, - ) - else: - train_eval_datasets = {} - - # The manner in which parameters are initialized follows this order of - # preference: - # 1. From a T5X checkpoint in `model_dir`, if one exists. - # 2. From a T5X or TF checkpoint specified by `cfg.path`, if set. - # 3. From scratch using `init_fn`. - - # 1. From a T5X checkpoint in `model_dir`, if one exists. - if checkpoint_cfg.restore is not None: - state_transforms_for_restore = [ - functools.partial(fn, is_resuming=True) - for fn in checkpoint_cfg.restore.state_transformation_fns - ] - else: - state_transforms_for_restore = [] - restore_cfgs = [ - utils.RestoreCheckpointConfig( - path=model_dir, - mode='latest', - dtype=checkpoint_cfg.save.dtype if checkpoint_cfg.save else 'float32', - checkpointer_cls=checkpoint_cfg.save.checkpointer_cls - if checkpoint_cfg.save - else checkpoints.Checkpointer, - # Restore dataset state if it is being saved. - restore_dataset=( - checkpoint_cfg.save and checkpoint_cfg.save.save_dataset - ), - state_transformation_fns=state_transforms_for_restore, - ) - ] - # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set. - if checkpoint_cfg.restore: - if checkpoint_cfg.restore.mode == 'all': - raise ValueError( - "Restore checkpoint mode 'all' is not supported in training." - ) - - # TODO(dhgarrette): Split "restore" behavior into separate configurations - # for the initial restoration for a new run, vs resuming a stopped run. - if isinstance(checkpoint_cfg.restore.path, str): - restore_cfgs.append(checkpoint_cfg.restore) - elif not checkpoint_cfg.restore.path: - # `path` is an empty (non-`str`) sequence, so there is nothing to restore. - pass - else: - raise ValueError( - 'Restore checkpoint config may only have a single path in training.' - ) - - init_or_restore_tick = time.time() - train_state_initializer = train_state_initializer_cls( - optimizer_def=model.optimizer_def, - init_fn=model.get_initial_variables, - input_shapes=input_shapes, - input_types=input_types, - partitioner=partitioner, - ) - - # May be None, empty - valid_restore_cfg, restore_paths = ( - utils.get_first_valid_restore_config_and_paths(restore_cfgs) - ) - if len(restore_paths) > 1: - raise ValueError('Multiple restore paths not permitted in training.') - - # Skip initialization if neither save nor restore is requested. - train_state = None - checkpoint_manager = None - if valid_restore_cfg or checkpoint_period or checkpoint_steps: - train_state, checkpoint_manager = ( - utils.create_checkpoint_manager_and_restore( - train_state_initializer, - partitioner, - valid_restore_cfg, - restore_paths[0] if restore_paths else None, - init_rng, - save_checkpoint_cfg=checkpoint_cfg.save, - model_dir=model_dir, - ds_iter=train_iter, - use_orbax=use_orbax, - ) - ) - - # Start warming up the input pipeline in the background. This must happen - # after input pipeline checkpoints were restored. - first_batch_ready = train_iter.peek_async() - - # 3. If no checkpoint to restore, init from scratch. - train_state = train_state or train_state_initializer.from_scratch(init_rng) - train_state_axes = train_state_initializer.train_state_axes - init_or_restore_secs = time.time() - init_or_restore_tick - logging.info( - 'Initialize/restore complete (%.2f seconds).', init_or_restore_secs - ) - - # Log the variable shapes information and write to a file. - log_file = os.path.join(model_dir, 'model-info.txt') - utils.log_model_info( - log_file, train_state_initializer.global_train_state_shape, partitioner - ) - - # Restore step from last checkpoint or set to 0 if training from scratch. - host_step = int(utils.get_local_data(train_state.step)) # pytype: disable=attribute-error - - if relative_steps: - total_steps = host_step + relative_steps - - if eval_period or checkpoint_period or gc_period: - steps_per_epoch = min( - eval_period or np.inf, checkpoint_period or np.inf, gc_period or np.inf - ) - else: - steps_per_epoch = total_steps - stats_period = stats_period or steps_per_epoch - if ( - eval_period - and eval_period % steps_per_epoch - or checkpoint_period - and checkpoint_period % steps_per_epoch - or gc_period - and gc_period % steps_per_epoch - ): - raise ValueError( - f'Checkpoint period ({checkpoint_period}), eval ' - f'period ({eval_period}), and GC period ({gc_period}) must all be ' - 'multiples of each other.' - ) - - # --------------------------------------------------------------------------- - # Trainer - # --------------------------------------------------------------------------- - - trainer: trainer_lib.BaseTrainer = trainer_cls( # pytype: disable=wrong-arg-types - model=model, - train_state=train_state, - partitioner=partitioner, - train_state_axes=train_state_axes, - eval_names=train_eval_datasets.keys(), - summary_dir=model_dir, - rng=trainer_rng, - ) - del train_state - - train_metrics = trainer.train_metrics_manager - summarize_config_fn(model_dir, train_metrics.summary_writer, host_step) - - train_metrics.write_scalar( - 'timing/init_or_restore_seconds', init_or_restore_secs, host_step - ) - - # ---------------------------------------------------------------------------- - # SeqIO (inference-based) evaluation setup - # ---------------------------------------------------------------------------- - # Init evaluator to set up cached datasets - evaluator = None - if infer_eval_dataset_cfg is not None: - evaluator = eval_lib.InferenceEvaluator( - infer_eval_dataset_cfg=infer_eval_dataset_cfg, - inference_evaluator_cls=inference_evaluator_cls, - model=model, - partitioner=partitioner, - log_dir=model_dir, - verify_matching_vocabs_fn=verify_matching_vocabs_fn, - ) - if not evaluator.eval_tasks: - # Skip evaluation. - evaluator = None - - if actions is None: - actions = {} - - if set(actions.keys()).difference(_ACTION_KEYS): - raise ValueError( - f'actions keys must be one of {_ACTION_KEYS}, but got : ' - f'{actions.keys()}' - ) - - # Transform the string key into proper ActionMode enum. - actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()} - - if ( - concurrent_metrics - and actions.get(trainer_lib.ActionMode.INFER_EVAL, None) is not None - ): - logging.warning( - 'Actions for INFER_EVAL will not be triggered when async ' - 'metrics computation is enabled' - ) - if ( - concurrent_metrics - and actions.get(trainer_lib.ActionMode.TRAIN, None) is not None - ): - logging.warning( - 'Actions for TRAIN will not be triggered when async ' - 'metrics computation is enabled' - ) - - # ---------------------------------------------------------------------------- - # Setup Eval Utility Functions - # ---------------------------------------------------------------------------- - - def _run_training_eval(first_run: bool = False): - if first_run: - logging.info('Compiling training eval loop.') - trainer.compile_eval({ # pytype: disable=wrong-arg-types # jax-ndarray - task: utils.get_zeros_batch_like_dataset(ds) - for task, ds in train_eval_datasets.items() - }) - logging.info('Computing training evaluation metrics.') - eval_batch_iters = {} - for task, ds in train_eval_datasets.items(): - if isinstance(ds, tf.data.Dataset): - eval_batch_iters[task] = ds.as_numpy_iterator() - else: - eval_batch_iters[task] = ds - - eval_summaries = trainer.eval(eval_batch_iters) - trainer.stop_training = run_actions( - trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray - actions, - trainer.train_state, - eval_summaries, - ) - - def _run_inference_eval(): - """Run prediction based inference eval.""" - if evaluator is None: - return - logging.info('Running inference evaluation.') - evaluate_tick = time.time() - all_metrics = evaluator.evaluate(trainer.train_state, train_state_axes) - if not concurrent_metrics: - # Ensure metrics are finished being computed. - all_metrics_done = all_metrics.result() or {} - trainer.stop_training = run_actions( - trainer_lib.ActionMode.INFER_EVAL, - actions, - trainer.train_state, - all_metrics_done, - ) - train_metrics.write_scalar( - 'timing/evaluate_seconds', time.time() - evaluate_tick, host_step - ) - - # Optionally run teacher-forcing training eval and SeqIO inference-base eval - # before training. Useful for testing how much a model knows before any - # finetuning. - if run_eval_before_training: - if train_eval_datasets: - logging.info('Running training eval before training.') - _run_training_eval(first_run=True) - if evaluator is not None: - logging.info('Running inference eval before training.') - _run_inference_eval() - - # Save checkpoints before the training loop starts. - if checkpoint_period and checkpoint_manager: - # If not using Orbax, always save checkpoint, otherwise, only save a - # checkpoint if a checkpoint does not already exist for that step. This is - # because Orbax will error out if we try to save a checkpoint that already - # exists. - if not use_orbax or ( - use_orbax - and utils.get_local_data(trainer.train_state.step) - not in checkpoint_manager.all_steps() - ): - logging.info('Saving checkpoint before the training loop starts.') - checkpoint_manager.save( - trainer.train_state, - checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error - ) - - # If we take manual control of the garbage collector, we need to disable it - # before starting training. - if gc_period: - gc.disable() - - # ---------------------------------------------------------------------------- - # Main training loop - # ---------------------------------------------------------------------------- - logging.info('Starting training loop.') - - def _cleanup() -> None: - """Ensures everything has been closed upon completion.""" - trainer.close() - if evaluator: - evaluator.close() - utils.sync_global_devices('complete') - logging.info('Finished.') - - first_step = host_step - - if total_steps < first_step: - raise ValueError( - f'Unexpected total_steps ({total_steps}) < checkpoint step ' - f' ({first_step}).' - ) - elif total_steps == first_step: - logging.warning( - 'Total training steps and checkpoint step were both %d, so no training ' - 'will be done. If you are only doing evaluation, this is expected. ' - 'Stopping now.', - total_steps, - ) - _cleanup() - return host_step, trainer.train_state - - logging.info('Starting main loop over steps %d-%d', first_step, total_steps) - - steps_per_epoch = min(steps_per_epoch, total_steps) - first_epoch = first_step // steps_per_epoch - num_epochs = first_epoch + math.ceil( - (total_steps - first_step) / steps_per_epoch - ) - logging.info( - 'Training with artificial "epochs" of %d steps.', steps_per_epoch - ) - - logging.info('Compiling train loop.') - logging.flush() - - def _as_gda(spec): - dummy = np.ones((data_layout.batch_size, *spec.shape[1:]), spec.dtype) - return jax.make_array_from_callback( - dummy.shape, - jax.sharding.NamedSharding( - partitioner.mesh, partitioner.data_partition_spec - ), - lambda idx: dummy[idx], - ) - - # Construct dummy batch for compiling the model. - dummy_batch = jax.tree.map(_as_gda, train_iter.element_spec) - if not isinstance(dummy_batch, Mapping): - raise ValueError( - 'Training loop expects batches to have type ' - f'Mapping[str, np.ndarray] but got {type(dummy_batch)}.' - ) - - assert isinstance(dummy_batch, Mapping) - trainer.compile_train(dummy_batch) - - # ---------------------------------------------------------------------------- - # Warmup input pipeline. - # ---------------------------------------------------------------------------- - train_iter_warmup_tick = time.time() - # We are cheating here. The input pipeline already started warmup when - # first_batch_ready was created. The warmup was then interleaved with the - # model compilation above. We just measure the additional time needed. - first_batch_ready.result() - train_iter_warmup_tock = time.time() - train_metrics.write_scalar( - 'timing/train_iter_warmup', - train_iter_warmup_tock - train_iter_warmup_tick, - host_step, - ) - - jax.monitoring.record_event_duration_secs( - '/jax/t5x/train/time_before_first_step_secs', time.time() - _IMPORT_TIME - ) - - # Current index within checkpoint_steps list for faster lookup runtime and - # for creating a checkpoint if needed between stats_period iterations. - checkpoint_steps_index = 0 - - # Main Loop over "epochs". - for epoch in range(first_epoch, num_epochs): - final_epoch = epoch == num_epochs - 1 - logging.info('Epoch %d of %d', epoch, num_epochs) - - # `stop_training` is requested, break out the main loop immediately. - if trainer.stop_training: - break - - logging.info('BEGIN Train loop.') - try: - # Until the last epoch, `num_steps = steps_per_epoch` - epoch_end_step = first_step + steps_per_epoch * (epoch - first_epoch + 1) - epoch_end_step = min(total_steps, epoch_end_step) - logging.info('Training for %d steps.', epoch_end_step - host_step) - while host_step < epoch_end_step: - if trainer.stop_training: - if checkpoint_period and checkpoint_manager: - logging.info('Saving a checkpoint before early stopping...') - checkpoint_manager.save( - trainer.train_state, - checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error - ) - logging.info( - 'Stopping training loop early since `stop_training` is requested.' - ) - break - inner_num_steps = min(epoch_end_step - host_step, stats_period) - - # first index in checkpoint_steps list will not always be 0 (in cases - # where first_step is non-zero, for example), so we must iterate to the - # first un-trained step in checkpoint_steps list to not re-train / - # save old steps - checkpoint_steps_index = utils.find_first_checkpoint_step( - checkpoint_steps_index, checkpoint_steps, first_step, host_step - ) - # check if inner_num_steps will skip a checkpoint_step that must be - # saved, if so, then iterate only to that step and save a checkpoint - # at that step and then continue with further iterations - is_checkpoint_step = False - (inner_num_steps, is_checkpoint_step) = utils.find_next_checkpoint_step( - checkpoint_steps_index, - inner_num_steps, - is_checkpoint_step, - host_step, - checkpoint_steps, - epoch_end_step, - checkpoint_period, - first_step, - ) - # Handled separately if this is the overall last step. - if host_step + inner_num_steps == total_steps: - is_checkpoint_step = False - - train_summary = trainer.train( - train_iter, inner_num_steps, start_step=host_step - ) - if not concurrent_metrics: - # Note that we always pass the dictionary of `tasks` -> summary so - # that the actions can be performed without special casing. The - # only caveat is that train would need its own special `key` - # given no `task` will be applied. - trainer.stop_training = run_actions( # pytype: disable=wrong-arg-types # jax-ndarray - trainer_lib.ActionMode.TRAIN, - actions, - trainer.train_state, - {TRAIN_METRIC_KEY: train_summary.result()}, - ) - - if is_checkpoint_step and checkpoint_manager: - logging.info('Saving a checkpoint at specified checkpoint step') - checkpoint_manager.save( - trainer.train_state, - checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error - ) - # Increment the checkpoint_step_index only if a checkpoint was saved. - if ( - checkpoint_steps - and checkpoint_steps_index < len(checkpoint_steps) - 1 - ): - checkpoint_steps_index += 1 - host_step += inner_num_steps - logging.info('END Train loop.') - except trainer_lib.PreemptionError as e: - if checkpoint_period and checkpoint_manager: - logging.info('Saving emergency checkpoint.') - checkpoint_manager.save( - trainer.train_state, - checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error - ) - checkpoint_manager.wait_until_finished() - logging.info('Saving emergency checkpoint done.') - raise e - - step_offset = host_step - first_step - - if gc_period and (final_epoch or step_offset % gc_period == 0): - gc.collect() - - # Maybe save a checkpoint if step is at period. - if ( - checkpoint_period - and (final_epoch or step_offset % checkpoint_period == 0) - and checkpoint_manager - ): - train_summary.result() - logging.info('Saving checkpoint.') - checkpoint_tick = time.time() - # Make sure last train step has completed before starting the clock. - checkpoint_manager.save( - trainer.train_state, - checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error - ) - # `_run_training_eval`` depends upon the result of the checkpoint, - # thus calling `wait_until_finished()`` here. - checkpoint_manager.wait_until_finished() - checkpoint_tock = time.time() - train_metrics.write_scalar( - 'timing/checkpoint_seconds', - checkpoint_tock - checkpoint_tick, - host_step, - ) - - is_eval_epoch = eval_period and ( - final_epoch or step_offset % eval_period == 0 - ) - - # Training Evaluation (i.e., with teacher forcing). - if is_eval_epoch and train_eval_datasets: - # Maybe less if final step < period. - first_run = step_offset // eval_period <= 1 - _run_training_eval(first_run and not run_eval_before_training) - - # Inference Evaluation (i.e., with decoding or scoring). - if is_eval_epoch and evaluator is not None: - _run_inference_eval() - if checkpoint_manager: - checkpoint_manager.close() - - # Wait until computations are done before exiting - _cleanup() - - if gc_period: - # Reenable garbage collection to avoid affecting future code executed in - # the same interpreter. - gc.enable() - - return host_step, trainer.train_state - - -if __name__ == '__main__': - # pylint: disable=g-import-not-at-top - from absl import app - from absl import flags - import fiddle as fdl - import gin - from t5x import config_utils - from t5x import gin_utils - # pylint: enable=g-import-not-at-top - - FLAGS = flags.FLAGS - - flags.DEFINE_multi_string( - 'gin_file', - default=None, - help=( - 'Path to gin configuration file. Multiple paths may be passed and ' - 'will be imported in the given order, with later configurations ' - 'overriding earlier ones.' - ), - ) - - flags.DEFINE_multi_string( - 'gin_bindings', default=[], help='Individual gin bindings.' - ) - - flags.DEFINE_list( - 'gin_search_paths', - default=['.'], - help=( - 'Comma-separated list of gin config path prefixes to be prepended ' - 'to suffixes given via `--gin_file`. If a file appears in. Only the ' - 'first prefix that produces a valid path for each suffix will be ' - 'used.' - ), - ) - - flags.DEFINE_string( - 'tfds_data_dir', - None, - 'If set, this directory will be used to store datasets prepared by ' - 'TensorFlow Datasets that are not available in the public TFDS GCS ' - 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' - 'all `Task`s.', - ) - - flags.DEFINE_list( - 'seqio_additional_cache_dirs', - [], - 'Directories to search for cached Tasks in addition to defaults.', - ) - - flags.DEFINE_boolean( - 'multiprocess_gpu', - False, - help=( - 'Initialize JAX distributed system for multi-host GPU, using ' - '`coordinator_address`, `process_count`, and `process_index`.' - ), - ) - - flags.DEFINE_string( - 'coordinator_address', - None, - help='IP address:port for multi-host GPU coordinator.', - ) - - flags.DEFINE_integer( - 'process_count', None, help='Number of processes for multi-host GPU.' - ) - - flags.DEFINE_integer('process_index', None, help='Index of this process.') - flags.DEFINE_integer( - 'initialization_timeout', - None, - help=( - 'Timeout for jax.distributed.initialize. Default used is the ' - 'default as specified in jax.distributed.initialize. ' - ), - ) - - - def main(argv: Sequence[str]): - """Wrapper for pdb post mortems.""" - _main(argv) - - def _main(argv: Sequence[str]): - """True main function.""" - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - # OOM fix. Prevents TF from seeing GPUs to stop conflict with JAX. - # This must go after InitGoogle(), which is called by - # gin_utils.run(main). - tf.config.experimental.set_visible_devices([], 'GPU') - - - if FLAGS.multiprocess_gpu: - logging.info( - 'Initializing distributed system for multi-host GPU:\n' - ' coordinator_address: %s\n process_count: %s\n process_index: %s', - FLAGS.coordinator_address, - FLAGS.process_count, - FLAGS.process_index, - ) - - if FLAGS.initialization_timeout: - if jax.__version__ < '0.4.15': - raise ValueError( - 'Specified' - f' --initialization_timeout={FLAGS.initialization_timeout}, but' - ' jax=={jax.__version__} does not support this yet. Use' - ' jax>=0.4.15' - ) - jax.distributed.initialize( - FLAGS.coordinator_address, - FLAGS.process_count, - FLAGS.process_index, - initialization_timeout=FLAGS.initialization_timeout, - ) - else: - jax.distributed.initialize( - FLAGS.coordinator_address, FLAGS.process_count, FLAGS.process_index - ) - - if FLAGS.tfds_data_dir: - seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) - - seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs) - - - if config_utils.using_fdl(): - config = config_utils.config_with_fiddle(train) - train_using_fiddle = fdl.build(config) - train_using_fiddle() - else: - # Create gin-configurable version of `train`. - train_using_gin = gin.configurable(train) - - gin_utils.parse_gin_flags( - # User-provided gin paths take precedence if relative paths conflict. - FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, - FLAGS.gin_file, - FLAGS.gin_bindings, - ) - train_using_gin() - - jax.effects_barrier() - - - config_utils.run(main) diff --git a/t5x-main/t5x/train_state.py b/t5x-main/t5x/train_state.py deleted file mode 100644 index 4ede9ea18b3126e49173ae82f41f48f9cd9c9076..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/train_state.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Train state for passing around objects during training.""" - -from typing import Any, Mapping, MutableMapping, Optional, Tuple - -from flax import traverse_util -import flax.core -from flax.core import scope as flax_scope -from flax.linen import partitioning as flax_partitioning -import flax.serialization -import flax.struct -import jax.numpy as jnp -from t5x import optimizers -import typing_extensions - -EMPTY_DICT = flax.core.freeze({}) -FrozenDict = flax_scope.FrozenDict -FrozenVariableDict = flax_scope.FrozenVariableDict -MutableVariableDict = flax_scope.MutableVariableDict -VariableDict = flax_scope.VariableDict - - -@typing_extensions.runtime_checkable -class TrainState(typing_extensions.Protocol): - """TrainState interface.""" - - @property - def step(self) -> jnp.ndarray: - """The current training step as an integer scalar.""" - ... - - @property - def params(self) -> FrozenVariableDict: - """The parameters of the model as a PyTree matching the Flax module.""" - ... - - @property - def param_states(self) -> FrozenVariableDict: - """The optimizer states of the parameters as a PyTree.""" - ... - - @property - def flax_mutables(self) -> FrozenVariableDict: - """Flax mutable collection.""" - ... - - def state_dict(self) -> MutableVariableDict: - """Returns a mutable representation of the state for checkpointing.""" - ... - - def restore_state(self, state_dict: Mapping[str, Any]) -> 'TrainState': - """Restores the object state from a state dict.""" - ... - - def replace_params(self, params: VariableDict) -> 'TrainState': - ... - - def replace_flax_mutables(self, flax_mutables: FrozenDict) -> 'TrainState': - ... - - def replace_step(self, step: jnp.ndarray) -> 'TrainState': - ... - - def apply_gradient( - self, grads, learning_rate, flax_mutables=EMPTY_DICT - ) -> 'TrainState': - """Applies gradient, increments step, and returns an updated TrainState.""" - ... - - def as_logical_axes(self) -> 'TrainState': - """Replaces `param` and `param-states` with their logical axis names.""" - ... - - -def _validate_params_axes(params_axes, params): - axis_names = flax_partitioning.get_axis_names(params_axes) - missing_params_axes = set(traverse_util.flatten_dict(params, sep='/')) - set( - traverse_util.flatten_dict(axis_names, sep='/') - ) - if missing_params_axes: - raise ValueError( - f'Missing axis names for parameters: {missing_params_axes}' - ) - - -def _split_variables_and_axes( - variables_and_axes: FrozenVariableDict, -) -> Tuple[FrozenVariableDict, FrozenVariableDict]: - """Splits `variables_and_axes` into two separate dicts with the same keys.""" - # For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`. - variables = {} - axes = {} - for k, v in variables_and_axes.items(): - if k.endswith('_axes'): - axes[k[:-5]] = v # k without "_axes". - _validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes". - else: - variables[k] = v - return flax.core.freeze(variables), flax.core.freeze(axes) - - -class FlaxOptimTrainState(flax.struct.PyTreeNode): - """Simple train state for holding parameters, step, optimizer state.""" - - _optimizer: optimizers.OptimizerType - # Contains axis metadata (e.g., names) matching parameter tree. - params_axes: Optional[FrozenVariableDict] = None - # Flax mutable fields. - flax_mutables: FrozenDict = EMPTY_DICT - # Contains axis metadata (e.g., names) matching flax_mutables tree. - flax_mutables_axes: Optional[FrozenVariableDict] = None - - @classmethod - def create( - cls, - optimizer_def: optimizers.OptimizerDefType, - model_variables: FrozenVariableDict, - ) -> 'FlaxOptimTrainState': - other_variables, params = flax.core.frozen_dict.pop( - model_variables, 'params' - ) - if 'params_axes' in other_variables: - other_variables, params_axes = flax.core.frozen_dict.pop( - other_variables, 'params_axes' - ) - _validate_params_axes(params_axes, params) - else: - params_axes = None - - # Split other_variables into mutables and their corresponding axes. - flax_mutables, flax_mutables_axes = _split_variables_and_axes( - other_variables - ) - - # If the optimizer supports `set_param_axes`, then assume that the model - # code is emitting these axes and use it. - if hasattr(optimizer_def, 'set_param_axes'): - if params_axes is None: - raise ValueError( - 'The optimizer supports params_axes for model-based ' - 'partitioning, but the model is not emitting them.' - ) - # `get_axis_names` removes "_axes" suffix in the leaf name and replaces - # `AxisMetadata` with `PartitionSpec`. - axis_names = flax_partitioning.get_axis_names(params_axes) - optimizer_def.set_param_axes(axis_names) - - optimizer = optimizer_def.create(params) - flax_mutables_axes = flax_mutables_axes or None - return FlaxOptimTrainState( - optimizer, - params_axes=params_axes, - flax_mutables=flax_mutables, - flax_mutables_axes=flax_mutables_axes, - ) - - @property - def step(self) -> jnp.ndarray: - return self._optimizer.state.step - - @property - def params(self) -> FrozenVariableDict: - return self._optimizer.target - - @property - def param_states(self) -> FrozenVariableDict: - return self._optimizer.state.param_states - - def state_dict(self) -> MutableVariableDict: - state_dict = self._optimizer.state_dict() - if self.flax_mutables: - state_dict['flax_mutables'] = flax.core.unfreeze(self.flax_mutables) - return state_dict - - def apply_gradient( - self, grads, learning_rate, flax_mutables=EMPTY_DICT - ) -> 'FlaxOptimTrainState': - new_optimizer = self._optimizer.apply_gradient( - grads, learning_rate=learning_rate - ) - return self.replace(_optimizer=new_optimizer, flax_mutables=flax_mutables) - - def replace_params(self, params: VariableDict) -> 'FlaxOptimTrainState': - return self.replace(_optimizer=self._optimizer.replace(target=params)) - - def replace_flax_mutables( - self, flax_mutables: FrozenDict - ) -> 'FlaxOptimTrainState': - return self.replace(flax_mutables=flax_mutables) - - def replace_step(self, step: jnp.ndarray) -> 'FlaxOptimTrainState': - state_dict = self.state_dict() - state_dict['state']['step'] = step - return self.restore_state(state_dict) - - def restore_state(self, state_dict: VariableDict) -> 'FlaxOptimTrainState': - new_optimizer = self._optimizer.restore_state(state_dict) - return self.replace( - _optimizer=new_optimizer, - flax_mutables=flax.core.freeze(state_dict['flax_mutables']) - if 'flax_mutables' in state_dict - else EMPTY_DICT, - ) - - def as_logical_axes(self) -> 'FlaxOptimTrainState': - if not hasattr(self._optimizer.optimizer_def, 'derive_logical_axes'): - raise ValueError( - f"Optimizer '{self._optimizer.optimizer_def.__class__.__name__}' " - 'requires a `derive_logical_axes` method to be used with named axis ' - 'partitioning.' - ) - flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT - return FlaxOptimTrainState( - _optimizer=self._optimizer.optimizer_def.derive_logical_axes( - self._optimizer, flax_partitioning.get_axis_names(self.params_axes) - ), - flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes), - ) - - -class InferenceState(flax.struct.PyTreeNode): - """State compatible with FlaxOptimTrainState without optimizer state.""" - - step: Optional[jnp.ndarray] - params: flax_scope.FrozenVariableDict - params_axes: Optional[flax_scope.FrozenVariableDict] = None - flax_mutables: flax_scope.FrozenDict = EMPTY_DICT - flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None - - @classmethod - def create(cls, model_variables: FrozenVariableDict) -> 'InferenceState': - other_variables, params = flax.core.frozen_dict.pop( - model_variables, 'params' - ) - if 'params_axes' in other_variables: - other_variables, params_axes = flax.core.frozen_dict.pop( - other_variables, 'params_axes' - ) - _validate_params_axes(params_axes, params) - else: - params_axes = None - - # Split other_variables into mutables and their corresponding axes. - flax_mutables, flax_mutables_axes = _split_variables_and_axes( - other_variables - ) - flax_mutables_axes = flax_mutables_axes or None - return InferenceState( - step=jnp.array(0), - params=params, - params_axes=params_axes, - flax_mutables=flax_mutables, - flax_mutables_axes=flax_mutables_axes, - ) - - @property - def param_states(self) -> FrozenVariableDict: - """The optimizer states of the parameters as a PyTree.""" - raise NotImplementedError('InferenceState has no optimizer states.') - - def apply_gradient(self, *args, **kwargs) -> 'InferenceState': - raise NotImplementedError( - 'InferenceState does not support `apply_gradient`.' - ) - - def state_dict(self) -> MutableMapping[str, Any]: - state_dict = { - 'target': flax.core.unfreeze(self.params), - 'state': {'step': self.step}, - } - if self.flax_mutables: - state_dict['flax_mutables'] = flax.core.unfreeze(self.flax_mutables) - return state_dict - - def replace_step(self, step: jnp.ndarray) -> 'InferenceState': - return self.replace(step=step) - - def replace_params(self, params: FrozenVariableDict) -> 'InferenceState': - return self.replace(params=params) - - def replace_flax_mutables( - self, flax_mutables: FrozenDict - ) -> 'InferenceState': - return self.replace(flax_mutables=flax_mutables) - - def restore_state(self, state_dict: Mapping[str, Any]) -> 'InferenceState': - return self.replace( - params=flax.core.freeze(state_dict['target']), - step=state_dict['state']['step'], - flax_mutables=flax.core.freeze(state_dict['flax_mutables']) - if 'flax_mutables' in state_dict - else EMPTY_DICT, - ) - - def as_logical_axes(self) -> 'InferenceState': - # Set step to None so that when the logical axes are processed by the - # flax.partitioning.logical_to_mesh_axes function, it will be skipped - # because jax.tree.map will short circut and never call the function on the - # step. - flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT - return InferenceState( - step=None, - params=flax_partitioning.get_axis_names(self.params_axes), - flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes), - ) diff --git a/t5x-main/t5x/train_state_test.py b/t5x-main/t5x/train_state_test.py deleted file mode 100644 index 1e37d4674e722c2a25110d0a127e550d8b3456d7..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/train_state_test.py +++ /dev/null @@ -1,660 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for train_state.""" - -from absl.testing import absltest -from flax import linen as nn -import flax.core -from flax.linen import partitioning as flax_partitioning -import jax -import numpy as np -from t5x import adafactor -from t5x import optimizers -from t5x import partitioning -from t5x import train_state as train_state_lib - -mock = absltest.mock -AxisMetadata = flax_partitioning.AxisMetadata -FactorDim = adafactor.FactorDim - - -class FlaxOptimTrainStateTest(absltest.TestCase): - - def test_init(self): - model = nn.Dense(10) - inputs = np.ones([2, 3], dtype=np.float32) - params = model.init(jax.random.PRNGKey(0), inputs)['params'] - optimizer_def = optimizers.adam(0.1) - optimizer = optimizer_def.create(params) - flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)}) - state = train_state_lib.FlaxOptimTrainState( - optimizer, flax_mutables=flax_mutables - ) - self.assertEqual(state.step, 0) - self.assertIsInstance(state._optimizer, optimizers.Optimizer) - self.assertEqual( - state.state_dict()['flax_mutables'], flax.core.unfreeze(flax_mutables) - ) - jax.tree.map(np.testing.assert_array_equal, params, state.params) - jax.tree.map( - np.testing.assert_array_equal, - optimizer.state.param_states, - state.param_states, - ) - - def test_create(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'mutables': np.ones(3), - }) - optimizer_def = optimizers.sgd(0.42) - state = train_state_lib.FlaxOptimTrainState.create( - optimizer_def, model_variables - ) - self.assertEqual(state.step, 0) - self.assertIsInstance(state._optimizer, optimizers.Optimizer) - self.assertEqual(state._optimizer.optimizer_def, optimizer_def) - jax.tree.map( - np.testing.assert_array_equal, - state.flax_mutables, - flax.core.freeze({'mutables': np.ones(3)}), - ) - jax.tree.map( - np.testing.assert_array_equal, state.params, model_variables['params'] - ) - self.assertIsNone(state.params_axes) - - def test_create_with_params_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - }) - optimizer_def = adafactor.Adafactor( - 0.42, - logical_factor_rules={ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - }, - ) - state = train_state_lib.FlaxOptimTrainState.create( - optimizer_def, model_variables - ) - self.assertEqual(state.step, 0) - self.assertIsInstance(state._optimizer, optimizers.Optimizer) - self.assertEqual(state._optimizer.optimizer_def, optimizer_def) - self.assertDictEqual( - state._optimizer.optimizer_def.hyper_params.factor_map, - { - 'dense/bias': (FactorDim.NONE,), - 'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW), - }, - ) - self.assertEqual(state.flax_mutables, flax.core.freeze({})) - jax.tree.map( - np.testing.assert_array_equal, model_variables['params'], state.params - ) - jax.tree.map( - np.testing.assert_array_equal, - model_variables['params_axes'], - state.params_axes, - ) - - def test_create_with_flax_mutables_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - 'grads': { - 'dense': { - 'output_grad': np.zeros(4), - } - }, - 'grads_axes': { - 'dense': { - 'output_grad': AxisMetadata(names=('embed',)), - } - }, - }) - optmizer_def = adafactor.Adafactor( - 0.42, - logical_factor_rules={ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - }, - ) - state = train_state_lib.FlaxOptimTrainState.create( - optmizer_def, model_variables - ) - self.assertEqual(state.step, 0) - self.assertIsInstance(state._optimizer, optimizers.Optimizer) - self.assertEqual(state._optimizer.optimizer_def, optmizer_def) - self.assertDictEqual( - state._optimizer.optimizer_def.hyper_params.factor_map, - { - 'dense/bias': (FactorDim.NONE,), - 'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW), - }, - ) - self.assertEqual( - state.flax_mutables, - flax.core.freeze({'grads': model_variables['grads']}), - ) - jax.tree.map( - np.testing.assert_array_equal, model_variables['params'], state.params - ) - jax.tree.map( - np.testing.assert_array_equal, - model_variables['params_axes'], - state.params_axes, - ) - jax.tree.map( - np.testing.assert_array_equal, - model_variables['grads_axes'], - state.flax_mutables_axes['grads'], - ) - - def test_create_missing_params_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'mutables': np.ones(3), - }) - with self.assertRaisesWithLiteralMatch( - ValueError, - 'The optimizer supports params_axes for model-based partitioning, but ' - 'the model is not emitting them.', - ): - train_state_lib.FlaxOptimTrainState.create( - adafactor.Adafactor(), model_variables - ) - - def test_create_mismatched_params_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - } - }, - 'mutables': np.ones(3), - }) - with self.assertRaisesWithLiteralMatch( - ValueError, "Missing axis names for parameters: {'dense/kernel'}" - ): - train_state_lib.FlaxOptimTrainState.create( - adafactor.Adafactor(), model_variables - ) - - def test_replace_params(self): - optimizer_def = optimizers.sgd(0.1) - optimizer = optimizer_def.create({'test': np.ones(10)}) - state = train_state_lib.FlaxOptimTrainState(optimizer) - - new_params = {'test': np.zeros(10)} - new_state = state.replace_params(new_params) - jax.tree.map(np.testing.assert_array_equal, new_params, new_state.params) - expected_state_dict = state.state_dict() - expected_state_dict['target'] = new_params - jax.tree.map( - np.testing.assert_array_equal, - expected_state_dict, - new_state.state_dict(), - ) - - def test_replace_step(self): - optimizer_def = optimizers.adam(0.1) - optimizer = optimizer_def.create({'test': np.ones(10)}) - state = train_state_lib.FlaxOptimTrainState(optimizer) - - self.assertEqual(state.step, 0) - self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1) - - def test_apply_gradient(self): - updated_optimizer = object() - optimizer = mock.Mock( - apply_gradient=mock.Mock(return_value=updated_optimizer) - ) - state = train_state_lib.FlaxOptimTrainState(optimizer) - - new_flax_mutables = {'test': 44} - new_state = state.apply_gradient( - grads=42, learning_rate=43, flax_mutables={'test': 44} - ) - - optimizer.apply_gradient.assert_called_once_with(42, learning_rate=43) - - self.assertEqual(new_state._optimizer, updated_optimizer) - self.assertEqual( - new_state, - train_state_lib.FlaxOptimTrainState( - updated_optimizer, flax_mutables=new_flax_mutables - ), - ) - - def test_as_logical_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - }) - optimizer_def = adafactor.Adafactor( - 0.42, - logical_factor_rules={ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - }, - ) - state = train_state_lib.FlaxOptimTrainState.create( - optimizer_def, model_variables - ) - axes_state = state.as_logical_axes() - self.assertIsNone(axes_state.params_axes) - jax.tree.map( - np.testing.assert_array_equal, - axes_state.params, - flax.core.freeze({ - 'dense': { - 'bias': partitioning.PartitionSpec('embed'), - 'kernel': partitioning.PartitionSpec('vocab', 'embed'), - } - }), - ) - - def test_as_logical_axes_with_flax_mutables(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - 'grads': { - 'dense': { - 'output_grad': np.zeros(4), - } - }, - 'grads_axes': { - 'dense': { - 'output_grad': AxisMetadata(names=('embed',)), - } - }, - }) - optmizer_def = adafactor.Adafactor( - 0.42, - logical_factor_rules={ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - }, - ) - state = train_state_lib.FlaxOptimTrainState.create( - optmizer_def, model_variables - ) - axes_state = state.as_logical_axes() - self.assertIsNone(axes_state.params_axes) - jax.tree.map( - np.testing.assert_array_equal, - axes_state.flax_mutables, - flax.core.freeze({ - 'grads': { - 'dense': { - 'output_grad': partitioning.PartitionSpec('embed'), - } - } - }), - ) - - def test_as_logical_axes_with_flax_mutables_without_mutables_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - 'some_variable_without_axes_defined': { - 'dense': { - 'kernel': np.zeros(4), - } - }, - }) - optmizer_def = adafactor.Adafactor( - 0.42, - logical_factor_rules={ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - }, - ) - state = train_state_lib.FlaxOptimTrainState.create( - optmizer_def, model_variables - ) - self.assertIsNone(state.flax_mutables_axes) # Not provided so must be None. - axes_state = state.as_logical_axes() - self.assertIsNone(axes_state.params_axes) - self.assertIsNone(axes_state.flax_mutables_axes) - jax.tree.map( - np.testing.assert_array_equal, - axes_state.flax_mutables, - flax.core.freeze({}), - ) - - def test_to_state_dict(self): - model_variables = flax.core.freeze({ - 'params': {'kernel': np.zeros((2, 4))}, - 'params_axes': { - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - }, - 'mutables': np.ones(3), - }) - optimizer_def = adafactor.Adafactor( - 0.42, - logical_factor_rules={ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - }, - ) - state = train_state_lib.FlaxOptimTrainState.create( - optimizer_def, model_variables - ) - jax.tree.map( - np.testing.assert_array_equal, - state.state_dict(), - { - 'state': { - 'step': np.array(0), - 'param_states': { - 'kernel': { - 'm': np.zeros(1), - 'v': np.zeros((2, 4)), - 'v_col': np.zeros(1), - 'v_row': np.zeros(1), - }, - }, - }, - 'target': {'kernel': np.zeros((2, 4))}, - 'flax_mutables': {'mutables': np.ones(3)}, - }, - ) - - def test_restore_state(self): - model_variables = flax.core.freeze({ - 'params': {'kernel': np.zeros((2, 4))}, - 'params_axes': { - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - }, - 'mutables': np.ones(3), - }) - optimizer_def = adafactor.Adafactor( - 0.42, - logical_factor_rules={ - 'vocab': FactorDim.COLUMN, - 'embed': FactorDim.ROW, - }, - ) - state = train_state_lib.FlaxOptimTrainState.create( - optimizer_def, model_variables - ) - restored = state.restore_state({ - 'state': { - 'step': np.array(1), - 'param_states': { - 'kernel': { - 'm': np.ones(1), - 'v': np.ones((2, 4)), - 'v_col': np.ones(1), - 'v_row': np.ones(1), - }, - }, - }, - 'target': {'kernel': np.ones((2, 4))}, - 'flax_mutables': {'mutables': np.zeros(3)}, - }) - - self.assertEqual(restored.step, 1) - self.assertIsInstance(restored._optimizer, optimizers.Optimizer) - self.assertEqual(restored._optimizer.optimizer_def, optimizer_def) - jax.tree.map( - np.testing.assert_array_equal, - restored.flax_mutables, - flax.core.freeze({'mutables': np.zeros(3)}), - ) - jax.tree.map( - np.testing.assert_array_equal, - restored.params, - flax.core.freeze({'kernel': np.ones((2, 4))}), - ) - jax.tree.map( - np.testing.assert_array_equal, - restored.param_states, - flax.core.freeze({ - 'kernel': adafactor._AdafactorParamState( - np.ones(1), np.ones(1), np.ones((2, 4)), np.ones(1) - ) - }), - ) - jax.tree.map( - np.testing.assert_array_equal, - restored.params_axes, - model_variables['params_axes'], - ) - - -class InferenceStateTest(absltest.TestCase): - - def test_init(self): - model = nn.Dense(10) - inputs = np.ones([2, 3], dtype=np.float32) - params = model.init(jax.random.PRNGKey(0), inputs)['params'] - flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)}) - state = train_state_lib.InferenceState( - step=jax.numpy.array(1), params=params, flax_mutables=flax_mutables - ) - self.assertEqual(state.step, 1) - self.assertEqual(state.flax_mutables, flax.core.unfreeze(flax_mutables)) - jax.tree.map(np.testing.assert_array_equal, params, state.params) - self.assertIsNone(state.params_axes) - - def test_create(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - 'mutables': np.ones(3), - }) - state = train_state_lib.InferenceState.create(model_variables) - self.assertEqual(state.step, 0) - jax.tree.map( - np.testing.assert_array_equal, - state.flax_mutables, - flax.core.freeze({'mutables': np.ones(3)}), - ) - jax.tree.map( - np.testing.assert_array_equal, state.params, model_variables['params'] - ) - jax.tree.map( - np.testing.assert_array_equal, - state.params_axes, - model_variables['params_axes'], - ) - - def test_create_mismatched_params_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - } - }, - 'mutables': np.ones(3), - }) - with self.assertRaisesWithLiteralMatch( - ValueError, "Missing axis names for parameters: {'dense/kernel'}" - ): - train_state_lib.InferenceState.create(model_variables) - - def test_replace_params(self): - model_variables = flax.core.freeze({'params': {'test': np.ones(10)}}) - state = train_state_lib.InferenceState.create(model_variables) - - new_params = {'test': np.zeros(10)} - new_state = state.replace_params(new_params) - jax.tree.map(np.testing.assert_array_equal, new_params, new_state.params) - - def test_replace_step(self): - model_variables = flax.core.freeze({'params': {'test': np.ones(10)}}) - state = train_state_lib.InferenceState.create(model_variables) - - self.assertEqual(state.step, 0) - self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1) - - def test_as_logical_axes(self): - model_variables = flax.core.freeze({ - 'params': {'dense': {'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}}, - 'params_axes': { - 'dense': { - 'bias_axes': AxisMetadata(names=('embed',)), - 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), - } - }, - }) - state = train_state_lib.InferenceState.create(model_variables) - axes_state = state.as_logical_axes() - self.assertIsNone(axes_state.params_axes) - jax.tree.map( - np.testing.assert_array_equal, - axes_state.params, - flax.core.freeze({ - 'dense': { - 'bias': partitioning.PartitionSpec('embed'), - 'kernel': partitioning.PartitionSpec('vocab', 'embed'), - } - }), - ) - - def test_to_state_dict(self): - model_variables = flax.core.freeze({ - 'params': { - 'bias': np.zeros(4), - }, - 'params_axes': { - 'bias_axes': AxisMetadata(names=('embed',)), - }, - 'mutables': np.ones(3), - }) - state = train_state_lib.InferenceState.create(model_variables) - jax.tree.map( - np.testing.assert_array_equal, - state.state_dict(), - { - 'state': {'step': np.array(0)}, - 'target': { - 'bias': np.zeros(4), - }, - 'flax_mutables': {'mutables': np.ones(3)}, - }, - ) - - def test_to_state_dict_no_mutables(self): - model_variables = flax.core.freeze({ - 'params': { - 'bias': np.zeros(4), - }, - 'params_axes': { - 'bias_axes': AxisMetadata(names=('embed',)), - }, - }) - state = train_state_lib.InferenceState.create(model_variables) - jax.tree.map( - np.testing.assert_array_equal, - state.state_dict(), - { - 'state': {'step': np.array(0)}, - 'target': { - 'bias': np.zeros(4), - }, - }, - ) - - def test_restore_state(self): - state = train_state_lib.InferenceState( - np.array(0), - {'bias': np.zeros(4)}, - {'bias_axes': AxisMetadata(names=('embed',))}, - ) - - state_dict = { - 'state': {'step': np.array(10)}, - 'target': { - 'bias': np.ones(4), - }, - 'flax_mutables': {'mutables': np.ones(3)}, - } - restored = state.restore_state(state_dict) - - self.assertEqual(restored.step, 10) - jax.tree.map( - np.testing.assert_array_equal, - restored.flax_mutables, - flax.core.freeze(state_dict['flax_mutables']), - ) - jax.tree.map( - np.testing.assert_array_equal, - restored.params, - flax.core.freeze(state_dict['target']), - ) - self.assertEqual( - restored.params_axes, {'bias_axes': AxisMetadata(names=('embed',))} - ) - - def test_restore_state_no_mutables_no_axes(self): - state = train_state_lib.InferenceState(np.array(0), {}) - - state_dict = { - 'state': {'step': np.array(10)}, - 'target': { - 'bias': np.zeros(4), - }, - } - restored = state.restore_state(state_dict) - - self.assertEqual(restored.step, 10) - self.assertEqual(restored.flax_mutables, train_state_lib.EMPTY_DICT) - jax.tree.map( - np.testing.assert_array_equal, - restored.params, - flax.core.freeze(state_dict['target']), - ) - self.assertIsNone(restored.params_axes) - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/trainer.py b/t5x-main/t5x/trainer.py deleted file mode 100644 index 965bd09a124e77f7abad2c7df5f2250519ae21fc..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/trainer.py +++ /dev/null @@ -1,1285 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Trainer and MetricsManager classes for use in train loop. - -To create a custom trainer, subclass `BaseTrainer` and implement -`_partitioned_train_step` and `_partitioned_eval_step` methods, -possibly by re-using the utility functions provided in this module. -""" - -import abc -import enum -import os -import threading -import time -from typing import Any, Dict, Iterator, Mapping, MutableMapping, Optional, Protocol, Sequence, TYPE_CHECKING, Tuple, Union - -from absl import logging -import cached_property -from clu import asynclib -from clu import metric_writers -import clu.data -import clu.metrics -import clu.values -from flax.core import FrozenDict -import jax -import jax.lax -import jax.numpy as jnp -import jax.random -import numpy as np -from t5x import metrics as metrics_lib -from t5x import models -from t5x import partitioning -from t5x import train_state as train_state_lib -from t5x import utils -import typing_extensions - - -Array = Union[np.ndarray, jnp.ndarray] -BatchSpec = Mapping[str, jax.ShapeDtypeStruct] -BatchType = Mapping[str, np.ndarray] -FlaxMutables = FrozenDict -Rng = jnp.ndarray -MetricMapType = MutableMapping[str, clu.metrics.Metric] -MetricMapSpec = Mapping[str, jax.ShapeDtypeStruct] -MetricValueMapType = Mapping[str, clu.values.Value] -ModelWeights = Any -MutableMetricMapType = Dict[str, clu.metrics.Metric] -PyTree = Any -PartitionSpec = partitioning.PartitionSpec - -if TYPE_CHECKING: # See b/163639353 - cached_property = property # pylint: disable=invalid-name -else: - cached_property = cached_property.cached_property - - -@jax.jit -def _merge_metrics(a, b): - return jax.tree_util.tree_map( - lambda a, b: a.merge(b), a, b, is_leaf=metrics_lib.is_metric_obj - ) - - -def _time() -> float: - """Indirection to `time.time` for mocking.""" - return time.time() - - -# Merges two metrics pytrees (mapping of metric_name (str) to clu.Metric object) -def merge_metrics(a, b): - a, b = jax.tree_util.tree_map(utils.get_local_data, (a, b)) - return _merge_metrics(a, b) - - -class ArrayMapFuture(typing_extensions.Protocol): - - def result(self) -> Mapping[str, Array]: - ... - - -class MetricValueMapFuture(typing_extensions.Protocol): - - def result(self) -> Mapping[str, clu.values.Value]: - ... - - -class TimeFuture(typing_extensions.Protocol): - - def result(self) -> float: - ... - - -class LearningRateCallable(typing_extensions.Protocol): - - def __call__( - self, - step: jnp.ndarray, - ) -> jnp.ndarray: - ... - - -class SummarizeMetricsCallable(typing_extensions.Protocol): - """PyType template for a metrics summary function.""" - - def __call__( - self, metrics: MetricMapType, duration: float, num_steps: int - ) -> Mapping[str, jnp.ndarray]: - """Summarizes metrics accumulated across multiple steps. - - Args: - metrics: Metrics accumulated across multiple steps. - duration: The duration of the run being summarized. - num_steps: The number of steps the metrics are accumulated across. - - Returns: - Summarized metrics. - """ - ... - - -class PartitionedTrainCallable(typing_extensions.Protocol): - """Protocol for a partitioned train step.""" - - def __call__( - self, train_state: train_state_lib.TrainState, batch: BatchType - ) -> Tuple[train_state_lib.TrainState, MetricMapType]: - ... - - -class PartitionedEvalCallable(typing_extensions.Protocol): - """Protocol for a partitioned eval step.""" - - def __call__( - self, train_state: train_state_lib.TrainState, batch: jnp.ndarray - ) -> MetricMapType: - ... - - -class WeightMetricsComputer(object): - """Decides which weight metrics to compute during training.""" - - _WEIGHT_METRICS = [ - "weight_rms", - "weight_gradient_rms", - "weight_update_rms", - "weight_max", - ] - - @staticmethod - def _make_rms_metrics(name, tree): - """Calculates the root-mean-square metric for a pytree.""" - return { - f"{name}/{k}": metrics_lib.AveragePerStep.from_model_output( - jnp.sqrt(jnp.mean(jnp.square(v))) - ) - for k, v in utils.flatten_dict_string_keys(tree).items() - } - - @staticmethod - def _make_max_metrics(name, tree): - """Calculates the L-inf norm for a pytree.""" - return { - f"{name}/{k}": metrics_lib.AveragePerStep.from_model_output( - jnp.max(jnp.abs(v)) - ) - for k, v in utils.flatten_dict_string_keys(tree).items() - } - - def compute_metrics( - self, - gradients: ModelWeights, - old_train_state: train_state_lib.TrainState, - new_train_state: train_state_lib.TrainState, - ) -> MutableMetricMapType: - """Compute some metrics about weights after having updating these weights. - - Args: - gradients: The gradients of all weights. - old_train_state: The training state before applying the gradients. - new_train_state: The training state after applying the gradients. - - Returns: - A dictionary of Metrics, where the keys are either metric names, or of the - form metric_name/parameter_name, depending on whether or not they are - global to the model, or specific to each model parameter. - """ - # TODO(reinerp): Extend weight stats logging with support for non-reduced - # axes of tensors. For example, for stacked layers (QKV stacking or layer - # stacking), we might not want to reduce over the stacking dimension, in - # order to provide more localization in the logged stats. - metrics = {} - metrics.update(self._make_rms_metrics("weight_rms", new_train_state.params)) - metrics.update(self._make_rms_metrics("weight_gradient_rms", gradients)) - grad_norm = jnp.sqrt( - jnp.sum( - jnp.array( - [jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(gradients)] - ) - ) - ) - metrics.update({ - "weight_gradient_norm": metrics_lib.AveragePerStep.from_model_output( - grad_norm - ) - }) - weight_update = jax.tree_util.tree_map( - jnp.subtract, new_train_state.params, old_train_state.params - ) - metrics.update(self._make_rms_metrics("weight_update_rms", weight_update)) - weight_update_by_weight = jax.tree_util.tree_map( - jnp.divide, weight_update, old_train_state.params - ) - metrics.update( - self._make_rms_metrics( - "weight_update_divided_by_weight_rms", weight_update_by_weight - ) - ) - metrics.update(self._make_max_metrics("weight_max", new_train_state.params)) - - return metrics - - -class _AsyncTimer(object): - """A timer that computes computes durations between async jax operations. - - You should call close() to wait for threads started by this class to finish. - """ - - def __init__(self): - # We use a thread pool with a single worker to ensure that calls to the - # function are run in order (but in a background thread). - self._pool = asynclib.Pool(thread_name_prefix="AsyncTimer", max_workers=1) - self._start_future = None - - def close(self): - self._pool.close() - - def __del__(self): - self.close() - - def _get_completion_future(self, block_on: PyTree = ()) -> TimeFuture: - """Returns Future containing time when `block_on` is ready.""" - - def _get_completion_time(): - try: - jax.block_until_ready(block_on) - except RuntimeError as e: - # If the buffer no longer exists, we assume it was completed. - buffer_deleted_message = ( - "INVALID_ARGUMENT: BlockHostUntilReady() " - "called on deleted or donated buffer" - ) - gda_buffer_deleted_message = ( - "INVALID_ARGUMENT: GetReadyFuture() " - "called on deleted or donated buffer" - ) - if str(e) not in (buffer_deleted_message, gda_buffer_deleted_message): - raise - return _time() - - return self._pool(_get_completion_time)() - - def start(self, block_on: PyTree = ()): - """Starts timer after `block_on` is ready.""" - self._start_future = self._get_completion_future(block_on) - - def stop(self, block_on: PyTree = ()) -> TimeFuture: - """Stops timer after `block_on` is ready, returning the duration.""" - if not self._start_future: - raise ValueError("The timer hasn't been started.") - - start_future = self._start_future - self._start_future = None - stop_future = self._get_completion_future(block_on) - return self._pool(lambda: stop_future.result() - start_future.result())() - - -class MetricsManager(object): - """Manages a set of distributed metrics and their logging. - - Logging is disabled on all but host 0. - - Logs to: - * TensorBoard - * ABSL - - You should call close() to wait for threads started by this class to finish. - """ - - def __init__( - self, - name: str, - summary_dir: Optional[str] = None, - ): - """MetricsManager constructor. - - Constructs an empty MetricWriter on all but host 0. - - Args: - name: an identifier of the metrics to use when logging (e.g., 'train'). - summary_dir: the summary directory. If provided, TensorBoard summaries - will be written to a `name` subdirectory. - """ - self._name = name - if jax.process_index() == 0: - self._writer = self._create_writer(name, summary_dir) - else: - self._writer = metric_writers.MultiWriter([]) - self.summary_dir = os.path.join(summary_dir, name) if summary_dir else None - self._writer_lock = threading.Lock() - # We use a thread pool with a single worker to ensure that calls to the - # function are run in order (but in a background thread). - self._summary_pool = asynclib.Pool( - thread_name_prefix="MetricsManager", max_workers=1 - ) - # Times the duration between steps. - self._duration_timer = _AsyncTimer() - - def _create_writer( - self, name: str, summary_dir: Optional[str] = None - ) -> metric_writers.MetricWriter: - """Creates the writer for host 0.""" - return metric_writers.create_default_writer( - summary_dir, - collection=name, - asynchronous=True, - ) - - def __del__(self): - self.close() - - def close(self): - try: - self._summary_pool.close() - finally: - try: - self._duration_timer.close() - finally: - if self._writer: - self._writer.close() - self._writer = None - - @property - def summary_writer(self) -> metric_writers.MetricWriter: - """Returns the MetricWriter used by this class.""" - # TODO(adarob): Make returned writer threadsafe. - assert self._writer is not None - return self._writer - - def write_scalar( - self, key: str, val: metric_writers.interface.Scalar, step: int - ): - """Writes scalar value to metric writers in a threadsafe manner.""" - step = int(utils.get_local_data(step)) - self.write_scalars(step, {key: val}) - - def write_scalars( - self, step: int, scalars: Mapping[str, metric_writers.interface.Scalar] - ): - """Writes scalar value to metric writers in a threadsafe manner.""" - step = utils.get_local_data(step) - with self._writer_lock: - assert self._writer is not None - self._writer.write_scalars(step, scalars) - - def start_duration_timer(self, block_on: PyTree = ()): - """Starts the duration timer.""" - self._duration_timer.start(block_on=block_on) - - def write_metrics_summary( - self, metrics: MetricMapType, step: int, num_steps: int - ) -> MetricValueMapFuture: - """Writes summary based on accumulated metrics in a background thread. - - Duration is automatically computed as the interval between completion of - metrics fetching. This closely approximates the duration of `num_steps`, - as the steps must be computes sequentually, and it is more accurate than - computing the time since the call to the step function since its actual - execution occurs asynchronously on the TPU/GPU device. - - Args: - metrics: acculumated metric values. - step: the current train step. - num_steps: the number of steps the metrics are accumulated across. - - Returns: - A mapping of name -> scalar value of the written summary. Only return the - real scalar value on host 0. For other hosts, return None. - """ - step = utils.get_local_data(step) - - # Must be called in the main thread to avoid race condition. - duration_future = self._duration_timer.stop(block_on=metrics) - - def _summarize_and_write(): - # For thread safety we first copy the metrics to host. - fetched_metrics = jax.tree_util.tree_map(jax.device_get, metrics) - - duration = duration_future.result() - # We set the duration on time-related metrics. - final_metrics = metrics_lib.set_time_metrics_duration( - fetched_metrics, duration - ) - # Set num_steps for Step metrics (AveragePerStep, StepsPerTime, ...) - final_metrics = metrics_lib.set_step_metrics_num_steps( - final_metrics, num_steps - ) - - # Ensure the metrics are not on device, which could lead to a deadlock. - def _ensure_not_on_device(x): - assert not isinstance(x, jax.Array) - - jax.tree_util.tree_map(_ensure_not_on_device, final_metrics) - final_metrics = jax.tree_util.tree_map( - utils.get_local_data, final_metrics - ) - - summary = {k: v.compute_value() for k, v in final_metrics.items()} - with self._writer_lock: - metric_writers.write_values(self._writer, int(step), summary) - - return summary - - return self._summary_pool(_summarize_and_write)() - - def flush(self): - try: - self._summary_pool.join() - finally: - if self._writer: - self._writer.flush() - - -class PreemptionError(Exception): - """Training has been interrupted and needs an emergency checkpoint.""" - - -class BaseTrainer(abc.ABC): - """Abstract base trainer class. - - Internally this uses MetricsManagers that start threads. You should - use the trainer as a context manager, or call close() directly in - order to wait for these threads to finish after training is done. - """ - - def __init__( - self, - model: models.BaseModel, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - eval_names: Sequence[str], - summary_dir: Optional[str], - train_state_axes: Any, - rng: Rng, - ): - """Trainer constructor. - - Args: - model: the instantiation of `BaseModel` to train. - train_state: A train state with model parameters and optimizer state. - partitioner: the partitioner to use. - eval_names: names of evaluation datasets, which must match the keys of the - mapping passed to `eval`. - summary_dir: optional directory to write TensorBoard metrics to. - train_state_axes: partitioning info for the train state to be used. - rng: jax PRNGKey seed for random operations, to be combined with step - number for a deterministic RNG. - """ - jax.monitoring.record_event("/jax/t5x/train/beacon") - self._model = model - self._train_state_axes = train_state_axes - self._base_rng = rng - self._partitioner = partitioner - self._compiled_train_step: Optional[PartitionedTrainCallable] = None - self._compiled_eval_steps: MutableMapping[str, PartitionedEvalCallable] = {} - self._compiled_eval_step_cache: MutableMapping[ - BatchSpec, PartitionedEvalCallable - ] = {} - - self._train_state_mutex = threading.RLock() - self._train_state = train_state - - self.stop_training = False - - # Time since the trainer was made, this will record the "uptime" of the job. - self._trainer_init_time = _time() - - # The training metrics combine metrics added by the Model (e.g., loss and - # accuracy) and Trainer (e.g., learning rate). - self.train_metrics_manager = MetricsManager( - "train", summary_dir=summary_dir - ) - - # The eval metrics only include metrics added by the Model. - self.eval_metrics_managers = { # pylint:disable=g-complex-comprehension - n: MetricsManager(f"training_eval/{n[:113]}", summary_dir=summary_dir) - for n in eval_names - } - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def close(self): - """Stops all train metric managers threads.""" - self.train_metrics_manager.close() - for mm in self.eval_metrics_managers.values(): - mm.close() - - def _get_step_rng(self, step: int) -> Rng: - return jax.random.fold_in(self._base_rng, step) - - @property - def train_state(self): - with self._train_state_mutex: - return self._train_state - - @train_state.setter - def train_state(self, train_state: PyTree): - with self._train_state_mutex: - self._train_state = train_state - - def train( - self, - batch_iter: Union[ - Iterator[BatchType], clu.data.dataset_iterator.DatasetIterator - ], - num_steps: int, - start_step: Optional[int] = None, - ) -> ArrayMapFuture: - """Runs the train loop for the given number of steps.""" - metrics = None - # Use pre-compiled step, if available. - train_step_fn = self._compiled_train_step or self._partitioned_train_step - - # We lock `train_state` access during the loop to avoid race conditions. - with self._train_state_mutex: - train_state = self.train_state - # Compute step number on host to avoid communication overhead. - start_step = int( - start_step if start_step is not None else train_state.step - ) - self.train_metrics_manager.start_duration_timer(block_on=train_state) - for step_num in range(start_step, start_step + num_steps): - logging.log_every_n_seconds( - logging.INFO, "Training: step %d", 10, step_num - ) - with jax.profiler.StepTraceAnnotation("train", step_num=step_num): - batch = next(batch_iter) - train_state, metrics_update = train_step_fn(train_state, batch) - if metrics: - metrics = merge_metrics(metrics, metrics_update) - else: - metrics = metrics_update - - self.train_state = train_state - - if metrics is not None: - metrics["timing/uptime"] = clu.metrics.LastValue.from_model_output( - jnp.asarray([_time() - self._trainer_init_time]) - ) - - return self.train_metrics_manager.write_metrics_summary( - metrics, start_step + num_steps, num_steps - ) - - def compile_train(self, batch: BatchType) -> None: - """Pre-compiles train step (if not yet compiled). - - Not required. - - If not called before `train`, compilation will occur automatically on the - first step and JAX's "jit cache" will be used to avoid recompilation for - future steps. - - Args: - batch: A sample batch that may contain dummy values, but with correct - shapes and dtypes. - """ - tick = _time() - self._compiled_train_step = self._partitioner.compile( - self._partitioned_train_step, self.train_state, batch - ) - tock = _time() - self.train_metrics_manager.write_scalar( - "timing/compilation_seconds", # pytype: disable=wrong-arg-types # jax-ndarray - tock - tick, - self.train_state.step, - ) - - def eval( - self, batch_iters: Mapping[str, Iterator[BatchType]] - ) -> Mapping[str, Array]: - """Runs evaluation loop over the iterator and writes summary.""" - eval_summaries = {} - train_state = self.train_state - for iter_name, batch_iter in batch_iters.items(): - logging.info("Evaluating: %s.", iter_name) - metrics = None - # Use a pre-compiled step function, if available. - eval_step_fn = self._compiled_eval_steps.get( - iter_name, self._partitioned_eval_step - ) - mm = self.eval_metrics_managers[iter_name] - - num_steps = 0 - mm.start_duration_timer(block_on=train_state) - for batch in batch_iter: - num_steps += 1 - utils.multihost_assert_equal( - jnp.array(num_steps), - "Eval step mismatch across hosts. Check for empty dataset shard.", - ) - if jax.process_count() > 1: - batch = partitioning.host_local_array_to_global_array( - batch, - self._partitioner.mesh, - self._partitioner.data_partition_spec, - ) - metrics_update = eval_step_fn(train_state, batch) - if metrics: - metrics = merge_metrics(metrics, metrics_update) - else: - metrics = metrics_update - utils.multihost_assert_equal( - jnp.array(-1), - "Eval step mismatch across hosts. Check for empty dataset shard.", - ) - - eval_summaries[iter_name] = mm.write_metrics_summary( # pytype: disable=wrong-arg-types # jax-ndarray - metrics, train_state.step, num_steps - ) - - # TODO(adarob): Return futures. - return {k: v.result() for k, v in eval_summaries.items()} - - def compile_eval(self, batches: Mapping[str, BatchType]) -> None: - """Pre-compiles eval step (if not yet compiled). - - Not required. - - Pre-compiles the evaluation step for each evaluation dataset, reusing cached - compilations where possible. In other words, if multiple evaluation datasets - have equivalent shapes/dtypes for the batch and initial metrics, - recompilation will be avoided. - - If not called before `eval`, compilation will occur automatically on the - first step and JAX's "jit cache" will be used to avoid recompilation for - future steps. - - Args: - batches: a mapping from evaluation dataset name to a sample batch. The - batch may contain dummy values, but the shapes and dtypes must be - correct. - """ - for eval_name, batch in batches.items(): - tick = _time() - cache_key: BatchSpec = FrozenDict(jax.eval_shape(lambda: batch)) # pylint:disable=cell-var-from-loop - if cache_key not in self._compiled_eval_step_cache: - if jax.process_count() > 1: - batch = partitioning.host_local_array_to_global_array( - batch, - self._partitioner.mesh, - self._partitioner.data_partition_spec, - ) - self._compiled_eval_step_cache[cache_key] = self._partitioner.compile( - self._partitioned_eval_step, self.train_state, batch - ) - self._compiled_eval_steps[eval_name] = self._compiled_eval_step_cache[ - cache_key - ] - tock = _time() - self.eval_metrics_managers[eval_name].write_scalar( # pytype: disable=wrong-arg-types # jax-ndarray - "timing/compilation_seconds", tock - tick, self.train_state.step - ) - - @property - @abc.abstractmethod - def _partitioned_train_step(self) -> PartitionedTrainCallable: - """Partitioned train step.""" - raise NotImplementedError - - @property - @abc.abstractmethod - def _partitioned_eval_step(self) -> PartitionedEvalCallable: - """Partitioned eval step.""" - raise NotImplementedError - - -def accumulate_grads_microbatched( - model: models.BaseModel, - train_state: train_state_lib.TrainState, - batch: BatchType, - dropout_rng: Rng, - num_microbatches: Optional[int], - data_partition_spec: PartitionSpec = PartitionSpec("data"), -) -> Tuple[ - train_state_lib.TrainState, MutableMetricMapType, Optional[FlaxMutables] -]: - """Implements optional microbatched gradient accumulation. - - Args: - model: the instantiation of `BaseModel` to train. - train_state: A train state with model parameters and optimizer state. - batch: input batch consisting of either - simply-padded batched features - 'encoder_input_tokens', 'decoder_input_tokens' 'decoder_target_tokens' - 'decoder_loss_weights'- packed, batched features with additional - "(encoder|decoder)_segment_id", "(encoder|decoder)_position" - dropout_rng: jax PRNGKey for dropout. - num_microbatches: the number of microbatches to use, or None for direct - training. - data_partition_spec: the PartitionSpec to use for partitioning annotations - on the batch. - - Returns: - Accumulated gradients and incremental metrics. - """ - batch_size = next(iter(batch.values())).shape[0] - - grad_fn = jax.value_and_grad(model.loss_fn, has_aux=True) - - # We assume that the model loss_fn supports flax mutables if and only if - # the train state includes non-empty flax mutables. - # Note: Default t5x models don't support flax_mutables. One needs to subclass - # them and return flax_mutables from `get_initial_variables` and `loss_fn`. - - initial_flax_mutables = ( - train_state.flax_mutables if train_state.flax_mutables else None - ) - - if num_microbatches is None or num_microbatches <= 1: - - if initial_flax_mutables is None: - (_, metrics), grad_accum = grad_fn(train_state.params, batch, dropout_rng) - flax_mutables = None - else: - (_, (metrics, flax_mutables)), grad_accum = grad_fn( - train_state.params, batch, dropout_rng, initial_flax_mutables - ) - else: - assert ( - batch_size % num_microbatches == 0 - ), "Batch size isn't divided evenly by num_microbatches." - microbatch_size = batch_size // num_microbatches - logging.info( - "using microbatches: %d microbatches, %d size", - num_microbatches, - microbatch_size, - ) - - def get_microbatch(batch: BatchType, idx: int) -> Mapping[str, jnp.ndarray]: - """Fetch microbatch slice from possibly-packed input data.""" - offset = idx * microbatch_size - length = microbatch_size - starts = {k: [offset] + [0] * (b.ndim - 1) for k, b in batch.items()} - limits = {k: [length] + list(b.shape[1:]) for k, b in batch.items()} - return { - k: jax.lax.dynamic_slice(b, starts[k], limits[k]) - for k, b in batch.items() - } - - def metrics_and_grad(loop_cnt, dropout_rng, flax_mutables=None): - dropout_rng, sub_dropout_rng = jax.random.split(dropout_rng) - mbatch = get_microbatch(batch, loop_cnt) - # We need to annotate the microbatch sharding as we would a batch. - mbatch = jax.tree_util.tree_map( - lambda x: partitioning.with_sharding_constraint( # pylint: disable=g-long-lambda - x, data_partition_spec - ), - mbatch, - ) - if flax_mutables is None: - (_, metrics), grad = grad_fn( - train_state.params, mbatch, sub_dropout_rng - ) - else: - (_, (metrics, flax_mutables)), grad = grad_fn( - train_state.params, mbatch, sub_dropout_rng, flax_mutables - ) - return metrics, grad, flax_mutables - - def per_microbatch_train_step( - loop_cnt: int, - state: Tuple[ - jnp.ndarray, - jnp.ndarray, - Mapping[str, jnp.ndarray], - Optional[FlaxMutables], - ], - ) -> Tuple[ - jnp.ndarray, - jnp.ndarray, - Mapping[str, jnp.ndarray], - Optional[FlaxMutables], - ]: - (dropout_rng, grad_accum, prev_metrics, flax_mutables) = state - metrics, grad, flax_mutables = metrics_and_grad( - loop_cnt, dropout_rng, flax_mutables - ) - - grad_accum = jax.tree_util.tree_map(jnp.add, grad_accum, grad) - metrics = jax.lax.cond( - loop_cnt == 0, - lambda _: metrics, - lambda _: merge_metrics(prev_metrics, metrics), - None, - ) - return dropout_rng, grad_accum, metrics, flax_mutables - - # Initialize gradient accumulation loop state. - accum_dtype = jnp.float32 - grad_accum_init = jax.tree_util.tree_map( - lambda x: jnp.zeros(x.shape, accum_dtype), train_state.params - ) - initial_metrics_shape, _, _ = jax.eval_shape( - metrics_and_grad, - loop_cnt=0, - dropout_rng=dropout_rng, - flax_mutables=initial_flax_mutables, - ) - - initial_metrics = { - k: metrics_lib.shape_obj_to_defined_obj(v) - for k, v in initial_metrics_shape.items() - } - loop_init = ( - dropout_rng, - grad_accum_init, - initial_metrics, - initial_flax_mutables, - ) - new_dropout_rng, grad_accum, metrics, flax_mutables = jax.lax.fori_loop( - 0, num_microbatches, per_microbatch_train_step, loop_init - ) - - del new_dropout_rng - - return grad_accum, metrics, flax_mutables - - -def apply_grads( - train_state: train_state_lib.TrainState, - grad_accum: ModelWeights, - metrics: MutableMetricMapType, - learning_rate: jnp.ndarray, - weight_metrics_computer: Optional[WeightMetricsComputer], - other_state_variables: Optional[Mapping[str, Any]] = None, -) -> Tuple[train_state_lib.TrainState, MetricMapType]: - """Applies gradients to the optimizer. - - Args: - train_state: A train state that contains model and optimizer params. - grad_accum: results of `accumulate_grads` call. - metrics: incremental metrics from `accumulate_grads` call. - learning_rate: the learning rate to use for this step. - weight_metrics_computer: A WeightMetricsComputer instance, or None, to - decide what metrics, if any, to log about weights and weight updates - during training. - other_state_variables: other variables to update the state with. - - Returns: - The updated train state, metrics. - """ - if other_state_variables is None: - other_state_variables = {} - # Update optimizer using accumulated gradient. - new_train_state = train_state.apply_gradient( - grad_accum, learning_rate=learning_rate, **other_state_variables - ) - metrics["learning_rate"] = clu.metrics.Average.from_model_output( - jnp.asarray([learning_rate]) - ) - metrics["learning_rate/current"] = clu.metrics.LastValue.from_model_output( - jnp.asarray([learning_rate]) - ) - if weight_metrics_computer is not None: - metrics.update( - weight_metrics_computer.compute_metrics( - grad_accum, train_state, new_train_state - ) - ) - return new_train_state, metrics - - -def eval_step( - model: models.BaseModel, - train_state: train_state_lib.TrainState, - batch: jnp.ndarray, -) -> MetricMapType: - """Default evaluation step.""" - if not train_state.flax_mutables: - _, metrics = model.eval_fn(train_state.params, batch) # pytype: disable=wrong-arg-types # jax-ndarray - else: - # If the training state contains mutable variables, then we expect the - # model to accept this extra arguments in the eval function. - # pytype: disable=wrong-arg-count - # pytype: disable=wrong-arg-types - _, metrics = model.eval_fn( - train_state.params, batch, train_state.flax_mutables - ) - # pytype: enable=wrong-arg-count - # pytype: enable=wrong-arg-types - return metrics - - -def train_with_lr( - train_state: train_state_lib.TrainState, - batch: BatchType, - learning_rate: jnp.ndarray, - dropout_rng: Rng, - model: models.BaseModel, - num_microbatches: Optional[int], - weight_metrics_computer: Optional[WeightMetricsComputer] = None, - data_partition_spec: PartitionSpec = PartitionSpec("data"), -): - """Main training function with LR schedule.""" - grad_accum, metrics, flax_mutables = accumulate_grads_microbatched( - model, - train_state, - batch, - dropout_rng, - num_microbatches, - data_partition_spec, - ) - new_train_state, metrics = apply_grads( - train_state, - grad_accum, - metrics, - learning_rate, - weight_metrics_computer, - other_state_variables={"flax_mutables": flax_mutables} - if flax_mutables - else None, - ) - - return new_train_state, metrics - - -class BaseTrainerConstructor(Protocol): - """A function that returns a BaseTrainer.""" - - def __call__( - self, - model: models.BaseModel, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - eval_names: Sequence[str], - summary_dir: Optional[str], - train_state_axes: Any, - rng: Rng, - ) -> BaseTrainer: - ... - - -class Trainer(BaseTrainer): - """Training loop with optional microbatches.""" - - def __init__( - self, - model: models.BaseModel, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - eval_names: Sequence[str], - summary_dir: Optional[str], - train_state_axes: Any, - rng: Rng, - learning_rate_fn: LearningRateCallable, - num_microbatches: Optional[int], - weight_metrics_computer: Optional[WeightMetricsComputer] = None, - ): - """Trainer constructor. - - Args: - model: the instantiation of `BaseModel` to train. - train_state: a train state with parameters and optimizer state. - partitioner: the partitioner to use. - eval_names: names of evaluation datasets, which must match the keys of the - mapping passed to `eval`. - summary_dir: optional directory to write TensorBoard metrics to. - train_state_axes: partitioning info for the optimizer to be used. - rng: jax PRNGKey seed for random operations, to be combined with step - number for a deterministic RNG. - learning_rate_fn: returns the learning rate given the current step. - num_microbatches: the number of microbatches to use, or None for direct - training. - weight_metrics_computer: A WeightMetricsComputer instance, or None, to - decide what metrics, if any, to log about weights and weight updates - during training. - """ - self._learning_rate_fn = learning_rate_fn - self._num_microbatches = num_microbatches - self._weight_metrics_computer = weight_metrics_computer - - super().__init__( - model=model, - train_state=train_state, - partitioner=partitioner, - eval_names=eval_names, - summary_dir=summary_dir, - train_state_axes=train_state_axes, - rng=rng, - ) - - @cached_property - def _partitioned_train_step(self) -> PartitionedTrainCallable: - def train_step(train_state: train_state_lib.TrainState, batch: BatchType): - return train_with_lr( - train_state, - batch, - learning_rate=self._learning_rate_fn(train_state.step), - dropout_rng=self._get_step_rng(train_state.step), # pytype: disable=wrong-arg-types # jax-ndarray - model=self._model, - num_microbatches=self._num_microbatches, - weight_metrics_computer=self._weight_metrics_computer, - data_partition_spec=self._partitioner.data_partition_spec, - ) - - return self._partitioner.partition( - train_step, - in_axis_resources=( - self._train_state_axes, - self._partitioner.data_partition_spec, - ), - out_axis_resources=(self._train_state_axes, None), - donate_argnums=(0,), - ) - - @cached_property - def _partitioned_eval_step(self) -> PartitionedEvalCallable: - return self._partitioner.partition( - lambda *args, **kwargs: eval_step(self._model, *args, **kwargs), - in_axis_resources=( - self._train_state_axes, - self._partitioner.data_partition_spec, - ), - out_axis_resources=None, - ) - - -def _warn_action_not_run(action, task, metric): - logging.warning( - "The action: %s that tracks metric: %s for task: %s is not run", - action, - metric, - task, - ) - - -# TODO(b/200701930): Support dynamic registration for enum. -@enum.unique -class ActionMode(enum.Enum): - """Defines when to run a action. - - For example, TRAIN means to run an action after a TRAIN loop is done. - """ - - TRAIN = 1 - TRAIN_EVAL = 2 - INFER_EVAL = 3 - - -class BaseAction(abc.ABC): - """Base Action class for override. The action itself does nothing.""" - - @abc.abstractmethod - def run( - self, - train_state: train_state_lib.TrainState, - metrics_by_task: Mapping[str, MetricValueMapType], - ) -> bool: - """Runs an action for the given train_state and metrics. - - Args: - train_state: The current train_state in the training loop. - metrics_by_task: A map of metrics that is grouped by each task. - - Returns: - A bool indicating whether training should be halted. - """ - raise NotImplementedError("Action must define its run method.") - - -ActionMapType = Mapping[ActionMode, Sequence[BaseAction]] - - -class EarlyStoppingAction(BaseAction): - """Terminates training when the specified metric is not improving. - - Checks whether the monitored metrics are decreasing after every `train` or - `eval`, or `both`. If the loss is no longer decreasing for `patience` times, - terminate the training process. - """ - - def __init__( - self, - metric: Tuple[str, str], - mode: str, - patience: int = 3, - atol: float = 0.0, - rtol: float = 0.0, - ): - """Constructs the EarlyStoppingAction. - - Args: - metric: A metric to monitor when invoking the action. When the metric does - not improve for a number of times (specified in patience), stop the - training. The tuple takes 2 strings, whereas the first string defines - the task to track, and the second defines the metric of the task to - track. e.g.,: ('mt5_xnli_dev_test.all_langs', 'accuracy') would monitor - the 'accuracy' for `mt5_xnli_dev_test.all_langs`. - mode: One of `{"min", "max"}`. In `min` mode, training will stop when the - quantity monitored has stopped decreasing; in `"max"` mode it will stop - when the quantity monitored has stopped increasing; - patience: The threshold of stopping criteria. Usually this is measured by - number of steps. - atol: Absolute tolerance in the monitored quantity to qualify as an - improvement, i.e. a change of less than `atol`, will count as no - improvement. - rtol: Relative tolerance in the monitoried quantity to qualify as an - improvement. This combined with `atol` defines whether a change is - considered improvement. The total change is calculated as following: - `delta = (atol + rtol * previous)` See `numpy.allclose` for detailed - information. - """ - self._task, self._metric = metric - if mode not in ["min", "max"]: - raise ValueError('mode must be in ["min", "max"]') - self._mode = mode - - if atol < 0: - raise ValueError("atol must be greater equal than 0") - self._atol = atol - - if rtol < 0: - raise ValueError("rtol must be greater equal than 0") - self._rtol = rtol - - self._patience = patience - self._metric_history = [] - - def _compare_fn(self, current, previous): - compare_fn = jnp.greater_equal if self._mode == "min" else jnp.less_equal - delta = self._atol + self._rtol * abs(previous) - if self._mode == "max": - delta *= -1 - return compare_fn(current, previous - delta) - - def run( - self, - train_state: train_state_lib.TrainState, - metrics_by_task: Mapping[str, MetricValueMapType], - ) -> bool: - if self._task not in metrics_by_task.keys(): - logging.warning( - "Monitoring task: %s does not exist in all task metrics. " - "Available tasks are : %s", - self._task, - metrics_by_task.keys(), - ) - _warn_action_not_run(type(self), self._task, self._metric) - return False - if self._metric not in metrics_by_task[self._task].keys(): - logging.warning( - "Metric : %s does not exist in metrics for task : %s", - self._metric, - self._task, - ) - _warn_action_not_run(type(self), self._task, self._metric) - return False - - m = metrics_by_task[self._task][self._metric] - - if isinstance(m, clu.values.Scalar): - self._metric_history.append(m.value) - - # For metrics returned from action_mode=INFER_EVAL (i.e. seqio.Evaluator) - elif isinstance(m, float): - self._metric_history.append(m) - else: - logging.warning( - "Metric %s does not have Scalar type. Found %s.", - self._metric, - type(m), - ) - _warn_action_not_run(type(self), self._task, self._metric) - return False - - # Not enough history. - if len(self._metric_history) < self._patience: - return False - - if all( - self._compare_fn(self._metric_history[i + 1], self._metric_history[i]) - for i in range(self._patience - 1) - ): - logging.warning( - "Requested `stop_training` in training loop (Details below).\n " - "Metric: %s for Task: %s has not improved for %s iterations, detail " - "history of the metric: %s", - self._metric, - self._task, - self._patience, - self._metric_history, - ) - return True - # Remove extra histories that we don't need to keep. - self._metric_history.pop(0) - return False - - -class TerminateOnNanAction(BaseAction): - """Terminates training when NaN loss is detected. - - Checks whether the loss metric for the given task is NaN or Inf and terminates - training if so. - """ - - def __init__(self, task: str, metric: str = "loss"): - """Constructs the TerminateOnNanAction. - - Args: - task: Defines the task from which to track the given metric. - metric: Defines a metric to track for NaN values (defaults to "loss"). - """ - self._task = task - self._metric = metric - - def run( - self, - train_state: train_state_lib.TrainState, - metrics_by_task: Mapping[str, MetricValueMapType], - ) -> bool: - if self._task not in metrics_by_task.keys(): - logging.warning( - "Monitoring task: %s does not exist in all task metrics. " - "Available tasks are : %s", - self._task, - metrics_by_task.keys(), - ) - _warn_action_not_run(type(self), self._task, self._metric) - return False - if self._metric not in metrics_by_task[self._task].keys(): - logging.warning( - "Metric : %s does not exist in metrics for task : %s", - self._metric, - self._task, - ) - _warn_action_not_run(type(self), self._task, self._metric) - return False - - metric = metrics_by_task[self._task][self._metric] - - if not isinstance(metric, clu.values.Scalar): - logging.warning( - "Metric %s does not have Scalar type. Found %s.", - self._metric, - type(metric), - ) - _warn_action_not_run(type(self), self._task, self._metric) - return False - - value = metric.value - if np.isnan(value) or np.isinf(value): - logging.warning( - "Requested `stop_training` in training loop (Details below).\n " - "NaN encountered in metric for task : %s", - self._task, - ) - return True - - return False diff --git a/t5x-main/t5x/trainer_test.py b/t5x-main/t5x/trainer_test.py deleted file mode 100644 index ad9ce62d30c02d2f0ff993796f161de998224bd9..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/trainer_test.py +++ /dev/null @@ -1,1148 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.trainer_lib.""" - -import collections -import os - -from absl.testing import absltest -from absl.testing import parameterized -import chex -from clu import metric_writers -import clu.metrics -import clu.values -import flax -import jax -import jax.numpy as jnp -import numpy as np -from t5x import metrics as metrics_lib -from t5x import models as models_lib -from t5x import optimizers -from t5x import partitioning -from t5x import test_utils -from t5x import train_state as train_state_lib -from t5x import trainer as trainer_lib -import tensorflow as tf -from tensorflow.io import gfile - -mock = absltest.mock -jax.config.parse_flags_with_absl() - -FlaxMutables = flax.core.FrozenDict - - -def _validate_events(test_case, summary_dir, expected_metrics, steps): - summaries = gfile.listdir(summary_dir) - test_case.assertLen(summaries, 1) - summary_path = os.path.join(summary_dir, summaries[0]) - event_file = os.path.join(summary_path) - events = list(tf.compat.v1.train.summary_iterator(event_file)) - actual_events = {} - # First event is boilerplate - test_case.assertLen(events, len(steps) + 1) - for step, event in zip(steps, events[1:]): - test_case.assertEqual(event.step, step) - test_case.assertLen(event.summary.value, 1) - tensor = event.summary.value[0].tensor - if tensor.string_val: - actual_events[event.summary.value[0].tag] = tensor.string_val[0].decode() - else: - actual_events[event.summary.value[0].tag] = float(tf.make_ndarray(tensor)) - - jax.tree.map(test_case.assertAlmostEqual, actual_events, expected_metrics) - - -class MetricsManagerTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.model_dir = self.create_tempdir().full_path - - def test_summary_dir(self): - # All hosts have the summary dir. - with mock.patch('jax.process_index', return_value=0): - mm = trainer_lib.MetricsManager('eval', self.model_dir) - self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval')) - mm.close() - - with mock.patch('jax.process_index', return_value=1): - mm = trainer_lib.MetricsManager('eval', self.model_dir) - self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval')) - mm.close() - - def test_summary_writer(self): - # Only host 0 creates a non-empty summary writer. - with mock.patch('jax.process_index', return_value=1): - mm = trainer_lib.MetricsManager('eval', self.model_dir) - self.assertFalse(gfile.exists(mm.summary_dir)) - mm.close() - - with mock.patch('jax.process_index', return_value=0): - mm = trainer_lib.MetricsManager('eval', self.model_dir) - self.assertIsInstance(mm.summary_writer, metric_writers.MetricWriter) - self.assertTrue(gfile.exists(mm.summary_dir)) - mm.close() - - def test_write_scalar(self): - gfile.makedirs(os.path.join(self.model_dir, 'eval')) - - # tag, value, step - scalars = [('loss', 1.0, 1), ('accuracy', 100.0, 2)] - - # Only host 0 has actually writes summaries. - with mock.patch('jax.process_index', return_value=1): - mm = trainer_lib.MetricsManager('eval', self.model_dir) - for s in scalars: - mm.write_scalar(*s) - self.assertEmpty(gfile.listdir(mm.summary_dir)) - mm.close() - - with mock.patch('jax.process_index', return_value=0): - mm = trainer_lib.MetricsManager('eval', self.model_dir) - for s in scalars: - mm.write_scalar(*s) - mm.flush() - - summaries = gfile.listdir(mm.summary_dir) - self.assertLen(summaries, 1) - - event_file = os.path.join(mm.summary_dir, summaries[0]) - events = list(tf.compat.v1.train.summary_iterator(event_file)) - # First event is boilerplate - self.assertLen(events, 3) - for event, (tag, value, step) in zip(events[1:], scalars): - self.assertEqual(event.step, step) - self.assertLen(event.summary.value, 1) - self.assertEqual(event.summary.value[0].tag, tag) - self.assertEqual(tf.make_ndarray(event.summary.value[0].tensor), value) - mm.close() - - def test_write_metrics_summary(self): - gfile.makedirs(os.path.join(self.model_dir, 'eval')) - - @flax.struct.dataclass - class MockTextMetric(clu.metrics.Metric): - - def compute_value(self): - return clu.values.Text('test metric') - - accumulated_metrics = { - 'loss': metrics_lib.Sum(40.0), - 'accuracy': metrics_lib.AveragePerStep.from_model_output(20.0), - 'steps_per_second': metrics_lib.StepsPerTime(), - 'text': MockTextMetric(), - } - expected_values = { - 'loss': clu.values.Scalar(40.0), - 'accuracy': clu.values.Scalar(10.0), - 'steps_per_second': clu.values.Scalar(0.05), - 'text': clu.values.Text('test metric'), - } - with mock.patch('jax.process_index', return_value=0), mock.patch( - 't5x.trainer._time', side_effect=[0, 40] # start_time, end_time - ): - mm = trainer_lib.MetricsManager('eval', summary_dir=self.model_dir) - mm.start_duration_timer() - summary = mm.write_metrics_summary( - accumulated_metrics, step=4, num_steps=2 - ) - mm.flush() - - self.assertDictEqual(summary.result(), expected_values) - _validate_events( - self, - mm.summary_dir, - {k: v.value for k, v in expected_values.items()}, - steps=[4, 4, 4, 4], - ) - - mm.close() - - def test_timer_blocking_on_donated_buffer(self): - mm = trainer_lib.MetricsManager('train', summary_dir=None) - x = jnp.zeros(1) - - # Not deleted. - mm.start_duration_timer(block_on=x) - mm._duration_timer._start_future.result() - - # Deleted/donated. - x.addressable_data(0).delete() - mm.start_duration_timer(block_on=x) - mm._duration_timer._start_future.result() - - def test_timer_concurrency(self): - mm = trainer_lib.MetricsManager('train') - - n = 10 - with mock.patch( - 't5x.trainer._time', - side_effect=range(2 * n), # start_time, end_time - ): - for _ in range(n): - mm.start_duration_timer() - summary = mm.write_metrics_summary({'time': metrics_lib.Time()}, 0, 1) - self.assertEqual(1, summary.result()['time'].value) - mm.flush() - - -def fake_accum_grads( - model, optimizer, batch, rng, num_microbatches, data_partition_spec -): - del model, num_microbatches, rng, data_partition_spec - # Add `i` to each optimzer value. - i = batch['i'].sum() - grad_accum = jax.tree.map(lambda x: i, optimizer) - # Add j to each metric. - j = batch['j'].sum() - metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)} - return grad_accum, metrics, None - - -def fake_apply_grads( - optimizer, - grad_accum, - metrics, - learning_rate, - weight_metrics_computer, - other_state_variables=None, -): - del weight_metrics_computer - del other_state_variables - metrics['learning_rate'] = clu.metrics.Average(learning_rate, count=1) - optimizer = jax.tree.map(lambda x, g: x + g, optimizer, grad_accum) - return optimizer, metrics - - -def fake_eval_step(model, optimizer, batch): - del model, optimizer - # Add `i` to each metric. - i = batch['i'].sum() - - return {'loss': metrics_lib.Sum(i), 'accuracy': metrics_lib.Sum(i)} - - -def fake_eval_fn_without_weight_sum(params, batch): - del params - # Add `i` to each metric. - i = batch['i'].sum() - - loss = metrics_lib.Sum(i) - return loss, {'loss': loss, 'accuracy': metrics_lib.Sum(i)} - - -def fake_eval_fn_with_mutables(params, batch, flax_mutables): - assert flax_mutables is not None - del flax_mutables - del params - # Add `i` to each metric. - i = batch['i'].sum() - loss = metrics_lib.Sum(i) - return loss, {'loss': loss, 'accuracy': metrics_lib.Sum(i)} - - -def build_fake_grad_fn_without_weight_sum(has_aux, require_flax_mutables): - - def fake_grad_fn_without_weight_sum( - train_state_params, batch, dropout_rng, flax_mutables=None - ): - del dropout_rng, train_state_params - # Add `i` to each optimzer value. - i = batch['i'].sum() - optimizer = optimizers.Optimizer( - optimizers.sgd(0.1), - state=optimizers.OptimizerState( - step=0, param_states={'bias': 0, 'kernel': 0} - ), - target={'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}, - ) - train_state = train_state_lib.FlaxOptimTrainState(optimizer) - grad_accum = jax.tree.map(lambda x: i, train_state) - # Add j to each metric. - j = batch['j'].sum() - metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)} - - if require_flax_mutables or flax_mutables is not None: - aux = metrics, flax_mutables - else: - aux = metrics - - if has_aux: - return (None, aux), grad_accum.params - else: - return None, grad_accum.params - - return fake_grad_fn_without_weight_sum - - -def fake_value_and_grad_fn_without_weight_sum(callable_fn, has_aux=False): - del callable_fn - return build_fake_grad_fn_without_weight_sum(has_aux, False) - - -def fake_value_and_grad_fn_wo_weight_sum_w_mutables(callable_fn, has_aux=False): - del callable_fn - return build_fake_grad_fn_without_weight_sum(has_aux, True) - - -class TrainerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.init_optimizer = optimizers.Optimizer( - optimizers.sgd(0.1), - state=optimizers.OptimizerState( - step=0, param_states={'bias': 0, 'kernel': 0} - ), - target={'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}, - ) - self.init_train_state = train_state_lib.FlaxOptimTrainState( - self.init_optimizer - ) - train_state_axes = jax.tree.map(lambda x: None, self.init_train_state) - model_dir = self.create_tempdir().full_path - - mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} - self.dataset = ( - tf.data.Dataset.range(6).map(mapfn).batch(2, drop_remainder=True) - ) - - with mock.patch('t5x.trainer._time', side_effect=[0]): # trainer init - self.test_trainer = trainer_lib.Trainer( - mock.create_autospec(models_lib.BaseModel, instance=True), - self.init_train_state, - partitioning.PjitPartitioner(num_partitions=1), - eval_names=['task1', 'task2'], - summary_dir=model_dir, - train_state_axes=train_state_axes, - rng=np.ones(2, np.uint32), - learning_rate_fn=lambda step: 2 * step, - num_microbatches=None, - ) - - def tearDown(self) -> None: - self.test_trainer.close() - return super().tearDown() - - @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_accum_grads) - @mock.patch('t5x.trainer.apply_grads', fake_apply_grads) - def _test_train(self, precompile): - trainer = self.test_trainer - initial_rng = trainer._base_rng - - if precompile: - with mock.patch( - 't5x.trainer._time', side_effect=[0, 1] # compile start, end - ), mock.patch( - 'absl.logging.log' - ): # avoids hidden calls to time.time() - trainer.compile_train(next(self.dataset.as_numpy_iterator())) - trainer._compiled_train_step = mock.Mock( - side_effect=trainer._compiled_train_step - ) - - trainer._partitioned_train_step = mock.Mock( - side_effect=trainer._partitioned_train_step - ) - - num_steps = 2 - with mock.patch( - 't5x.trainer._time', - side_effect=[1, 5, 6], # start_time, uptime logged, end_time - ): - trainer.train(self.dataset.as_numpy_iterator(), num_steps).result() - - initial_metrics = { - 'loss': 0.0, - 'accuracy': 0.0, - } - expected_metrics = { - k: v + 2 * num_steps for k, v in initial_metrics.items() - } - # (0 + 2) / 2 = 1 - expected_metrics['learning_rate'] = 1 - # 5.0 - 0.0 - expected_metrics['timing/uptime'] = 5.0 - # 0+1+2+3 = 6 - expected_train_state = jax.tree.map( - lambda x: np.array(x + 6), self.init_train_state - ) - - # Base rng must remain the same - np.testing.assert_array_equal(trainer._base_rng, initial_rng) - jax.tree.map( - np.testing.assert_equal, trainer.train_state, expected_train_state - ) - # Expected step is 6 since we increment it along with the other optimizer - # values. - steps = [2, 2, 2, 2] - if precompile: - steps = [0] + steps - expected_metrics['timing/compilation_seconds'] = 1 - self.assertEqual(trainer._compiled_train_step.call_count, num_steps) - trainer._partitioned_train_step.assert_not_called() - else: - self.assertIsNone(trainer._compiled_train_step) - self.assertEqual(trainer._partitioned_train_step.call_count, num_steps) - trainer.train_metrics_manager.flush() - _validate_events( - self, - trainer.train_metrics_manager.summary_dir, - expected_metrics, - steps=steps, - ) - - def test_train_noprecompile(self): - self._test_train(False) - - def test_train_precompile(self): - self._test_train(True) - - @mock.patch('t5x.trainer.eval_step', fake_eval_step) - def _test_eval(self, precompile): - trainer = self.test_trainer - initial_rng = trainer._base_rng - - task_datasets = { - 'task1': self.dataset.take(2), - 'task2': self.dataset.repeat().take(5), - } - - if precompile: - # [task1 start, task1 end, task2 start, task2 end] - with mock.patch( - 't5x.trainer._time', - side_effect=[0, 1, 2, 3], # [t1 start, t1 end, t2 start, t2 end] - ), mock.patch( - 'absl.logging.log' - ): # avoids hidden calls to time.time() - trainer.compile_eval({ - task: next(ds.as_numpy_iterator()) - for task, ds in task_datasets.items() - }) - trainer._compiled_eval_steps = { - task: mock.Mock(side_effect=trainer._compiled_eval_steps[task]) - for task in task_datasets - } - - trainer._partitioned_eval_step = mock.Mock( - side_effect=trainer._partitioned_eval_step - ) - - with mock.patch( - 't5x.trainer._time', - side_effect=[1, 5, 5, 8], # t1 start, t1 end, t2 start, t2 end] - ): - trainer.eval( - {task: ds.as_numpy_iterator() for task, ds in task_datasets.items()} - ) - - all_expected_metrics = { - # 0+1+2+3 = 6 - 'task1': { - 'loss': 6, - 'accuracy': 6, - }, - # 0+1+2+3+4+5+0+1+2+3 = 21 - 'task2': { - 'loss': 21, - 'accuracy': 21, - }, - } - - np.testing.assert_array_equal(trainer._base_rng, initial_rng) - for task_name, expected_metrics in all_expected_metrics.items(): - steps = [0, 0] - if precompile: - steps = [0] + steps - expected_metrics['timing/compilation_seconds'] = 1 - self.assertEqual( # pylint:disable=g-generic-assert - trainer._compiled_eval_steps[task_name].call_count, - len(task_datasets[task_name]), - ) - trainer._partitioned_eval_step.assert_not_called() - else: - self.assertEmpty(trainer._compiled_eval_steps) - self.assertEqual( - trainer._partitioned_eval_step.call_count, - sum(len(ds) for ds in task_datasets.values()), - ) - mm = trainer.eval_metrics_managers[task_name] - mm.flush() - _validate_events(self, mm.summary_dir, expected_metrics, steps=steps) - - def test_eval_noprecompile(self): - self._test_eval(False) - - def test_eval_precompile(self): - self._test_eval(True) - - @parameterized.named_parameters([ - { - 'testcase_name': 'max_no_increase', - 'mode': 'max', - 'metrics': [1, 1, 1], - 'atol': 0.0, - 'rtol': 0.0, - 'stop_training': True, - }, - { - 'testcase_name': 'max_no_atol', - 'mode': 'max', - 'metrics': [1, 0.9, 0.8], - 'atol': 0.0, - 'rtol': 0.0, - 'stop_training': True, - }, - { - 'testcase_name': 'max_not_enough_atol', - 'mode': 'max', - 'metrics': [1, 1.09, 1.18], - 'atol': 0.1, - 'rtol': 0.0, - 'stop_training': True, - }, - { - 'testcase_name': 'max_enough_atol', - 'mode': 'max', - 'metrics': [1, 1.2, 1.4], - 'atol': 0.1, - 'rtol': 0.0, - 'stop_training': False, - }, - { - 'testcase_name': 'max_enough_atol_rtol', - 'mode': 'max', - # first delta = 0.1 + 1* 0.08 = 0.18 - # second delta = 0.1 + 1.2 * 0.08 = 0.196 - 'metrics': [1, 1.2, 1.4], - 'atol': 0.1, - 'rtol': 0.08, - 'stop_training': False, - }, - { - 'testcase_name': 'max_not_enough_rtol', - 'mode': 'max', - 'metrics': [1, 1.2, 1.4], - 'atol': 0.0, - 'rtol': 0.2, - 'stop_training': True, - }, - { - 'testcase_name': 'min_no_decrease', - 'mode': 'min', - 'metrics': [1, 1, 1], - 'atol': 0.0, - 'rtol': 0.0, - 'stop_training': True, - }, - { - 'testcase_name': 'min_no_atol', - 'mode': 'min', - 'metrics': [1, 1, 1], - 'atol': 0.0, - 'rtol': 0.0, - 'stop_training': True, - }, - { - 'testcase_name': 'min_not_enough_atol', - 'mode': 'min', - 'metrics': [1, 0.9, 0.71], - 'atol': 0.2, - 'rtol': 0.0, - 'stop_training': True, - }, - { - 'testcase_name': 'min_enough_atol', - 'mode': 'min', - 'metrics': [1, 0.8, 0.6], - 'atol': 0.15, - 'rtol': 0.0, - 'stop_training': False, - }, - { - 'testcase_name': 'min_enough_atol_rtol', - 'mode': 'min', - # first delta = 0.1 + 1* 0.09 = 0.19 - # second delta = 0.1 + 0.8 * 0.09 = 0.172 - 'metrics': [1, 0.8, 0.6], - 'atol': 0.1, - 'rtol': 0.09, - 'stop_training': False, - }, - { - 'testcase_name': 'min_not_enough_rtol', - 'mode': 'min', - 'metrics': [1, 0.8, 0.6], - 'atol': 0.0, - 'rtol': 0.3, - 'stop_training': True, - }, - { - 'testcase_name': 'longer_history', - 'mode': 'min', - 'metrics': [1, 0.8, 0.7, 0.6], - 'atol': 0.15, - 'rtol': 0.0, - 'stop_training': True, - }, - ]) - def test_early_stopping_action( - self, mode, metrics, atol, rtol, stop_training - ): - trainer = self.test_trainer - metrics = [clu.values.Scalar(metric) for metric in metrics] - hook = trainer_lib.EarlyStoppingAction( - ('test_task', 'metric'), mode=mode, patience=3, atol=atol, rtol=rtol - ) - - for metric in metrics: - trainer_stop_training = hook.run( - trainer.train_state, {'test_task': {'metric': metric}} - ) - - self.assertEqual(trainer_stop_training, stop_training) - - @parameterized.named_parameters([ - { - 'testcase_name': 'allow_clu_scalar_early_stopping', - 'metrics': [ - clu.values.Scalar(1), - clu.values.Scalar(0.9), - clu.values.Scalar(0.71), - ], - 'atol': 0.2, - 'stop_training': True, - }, - { - 'testcase_name': 'allow_float_early_stopping', - 'metrics': [1.0, 0.9, 0.71], - 'atol': 0.2, - 'stop_training': True, - }, - { - 'testcase_name': 'error_for_other_type', - 'metrics': [3, 2, 1], - 'atol': 1.1, - 'stop_training': False, - }, - ]) - def test_early_stopping_action_value(self, metrics, atol, stop_training): - trainer = self.test_trainer - hook = trainer_lib.EarlyStoppingAction( - ('test_task', 'metric'), mode='min', patience=3, atol=atol - ) - - for metric in metrics: - trainer_stop_training = hook.run( - trainer.train_state, {'test_task': {'metric': metric}} - ) - - self.assertEqual(trainer_stop_training, stop_training) - - @parameterized.named_parameters([ - { - 'testcase_name': 'invalid_task', - 'task': 'wrong_task', - 'metric': 'metric', - 'value': clu.values.Scalar(np.nan), - }, - { - 'testcase_name': 'invalid_metric_name', - 'task': 'task', - 'metric': 'wrong_metric_name', - 'value': clu.values.Scalar(np.nan), - }, - ]) - def test_early_stopping_action_error(self, task, metric, value): - trainer = self.test_trainer - hook = trainer_lib.EarlyStoppingAction( - (task, metric), mode='min', patience=5, atol=1, rtol=1 - ) - - trainer_stop_training = hook.run( - trainer.train_state, {task: {metric: value}} - ) - - self.assertFalse(trainer_stop_training) - - @parameterized.named_parameters([ - { - 'testcase_name': 'valid_loss', - 'metric': 'loss', - 'value': 1.0, - 'stop_training': False, - }, - { - 'testcase_name': 'nan', - 'metric': 'loss', - 'value': np.nan, - 'stop_training': True, - }, - { - 'testcase_name': 'inf', - 'metric': 'loss', - 'value': np.inf, - 'stop_training': True, - }, - { - 'testcase_name': 'other_metric', - 'metric': 'some_metric', - 'value': np.inf, - 'stop_training': True, - }, - ]) - def test_terminate_on_nan_action(self, metric, value, stop_training): - trainer = self.test_trainer - value = clu.values.Scalar(value) - hook = trainer_lib.TerminateOnNanAction(task='test_task', metric=metric) - - trainer_stop_training = hook.run( - trainer.train_state, {'test_task': {metric: value}} - ) - - self.assertEqual(trainer_stop_training, stop_training) - - @parameterized.named_parameters([ - { - 'testcase_name': 'invalid_task', - 'task': 'wrong_task', - 'metric': 'metric', - 'value': clu.values.Scalar(np.nan), - }, - { - 'testcase_name': 'invalid_metric_name', - 'task': 'task', - 'metric': 'wrong_metric_name', - 'value': clu.values.Scalar(np.nan), - }, - { - 'testcase_name': 'invalid_value', - 'task': 'task', - 'metric': 'metric', - 'value': 1.0, - }, - ]) - def test_terminate_on_nan_action_error(self, task, metric, value): - trainer = self.test_trainer - hook = trainer_lib.TerminateOnNanAction(task=task, metric=metric) - - trainer_stop_training = hook.run( - trainer.train_state, {'task': {'metric': value}} - ) - - self.assertFalse(trainer_stop_training) - - def test_compile_train(self): - trainer = self.test_trainer - trainer._partitioned_train_step = mock.Mock() - trainer.train_metrics_manager = mock.Mock() - - batch = { - 'i': np.arange(10, dtype=np.int32).reshape((2, 5)), - 'j': np.ones((), dtype=np.float32), - } - # compile start, compile end - with mock.patch('t5x.trainer._time', side_effect=[1, 5]): - trainer.compile_train(batch) - - trainer.train_metrics_manager.write_scalar.assert_called_with( - 'timing/compilation_seconds', 4, trainer.train_state.step - ) - trainer._partitioned_train_step.lower.assert_called_once() - train_step_args = trainer._partitioned_train_step.lower.call_args[0] - self.assertLen(train_step_args, 2) - self.assertEqual(train_step_args[0], trainer.train_state) - test_utils.assert_same(train_step_args[1], batch) - - def test_compile_eval(self): - trainer = self.test_trainer - trainer._partitioned_eval_step = mock.Mock() - trainer.eval_metrics_managers = { - 'eval1': mock.Mock(), - 'eval2': mock.Mock(), - 'eval3': mock.Mock(), - 'eval4': mock.Mock(), - } - trainer._partitioned_eval_step.lower().compile.side_effect = [ - 'compiled1', - 'compiled2', - 'compiled3', - ] - - batches = { - 'eval1': {'i': np.zeros((2, 5), dtype=np.int32)}, - 'eval2': {'j': np.zeros((), dtype=np.float32)}, - 'eval3': {'j': np.zeros((), dtype=np.float32)}, - 'eval4': {'k': np.zeros((4), dtype=np.float32)}, - } - - # eval1 start/end, eval2 start/end, eval3 start/end, eval 4 start/end - with mock.patch( - 't5x.trainer._time', side_effect=[1, 5, 6, 9, 10, 11, 12, 13] - ): - trainer.compile_eval(collections.OrderedDict(sorted(batches.items()))) - - trainer.eval_metrics_managers['eval1'].write_scalar.assert_called_with( - 'timing/compilation_seconds', 4, trainer.train_state.step - ) - trainer.eval_metrics_managers['eval2'].write_scalar.assert_called_with( - 'timing/compilation_seconds', 3, trainer.train_state.step - ) - trainer.eval_metrics_managers['eval3'].write_scalar.assert_called_with( - 'timing/compilation_seconds', 1, trainer.train_state.step - ) - trainer.eval_metrics_managers['eval4'].write_scalar.assert_called_with( - 'timing/compilation_seconds', 1, trainer.train_state.step - ) - eval_step_args = trainer._partitioned_eval_step.lower.call_args_list[1:] - self.assertLen(eval_step_args, 3) - - eval1_call_args = eval_step_args[0][0] - self.assertLen(eval1_call_args, 2) - self.assertEqual(eval1_call_args[0], trainer.train_state) - test_utils.assert_same( - eval1_call_args[1], - { - 'i': np.zeros((2, 5), dtype=np.int32), - }, - ) - - eval2_call_args = eval_step_args[1][0] - self.assertLen(eval2_call_args, 2) - self.assertEqual(eval2_call_args[0], trainer.train_state) - test_utils.assert_same( - eval2_call_args[1], - { - 'j': np.zeros((), dtype=np.float32), - }, - ) - - eval3_call_args = eval_step_args[2][0] - self.assertLen(eval3_call_args, 2) - self.assertEqual(eval3_call_args[0], trainer.train_state) - test_utils.assert_same( - eval3_call_args[1], - { - 'k': np.zeros((4), dtype=np.float32), - }, - ) - - self.assertDictEqual( - trainer._compiled_eval_steps, - { - 'eval1': 'compiled1', - 'eval2': 'compiled2', - 'eval3': 'compiled2', - 'eval4': 'compiled3', - }, - ) - - @mock.patch('jax.value_and_grad', fake_value_and_grad_fn_without_weight_sum) - def test_accumulate_grads_microbatched_without_weight_sum_single_batch(self): - batch_iter = self.dataset.as_numpy_iterator() - batch = next(batch_iter) - num_microbatches = 1 - grad_accum, metrics, flax_mutables = ( - trainer_lib.accumulate_grads_microbatched( - self.test_trainer._model, - self.init_train_state, - batch, - self.test_trainer._base_rng, - num_microbatches, - ) - ) - - i = batch['i'].sum() - expected_grad_accum = jax.tree.map( - lambda x: i, self.init_train_state - ).params - self.assertEqual(expected_grad_accum, grad_accum) - self.assertEqual(metrics['loss'].compute(), 2) - self.assertEqual(metrics['accuracy'].compute(), 2) - self.assertIsNone(flax_mutables) - - @mock.patch('jax.value_and_grad', fake_value_and_grad_fn_without_weight_sum) - def test_accumulate_grads_microbatched_without_weight_sum_multiple_batches( - self, - ): - batch_iter = self.dataset.as_numpy_iterator() - batch = next(batch_iter) - num_micro_batches = 2 - grad_accum, metrics, flax_mutables = ( - trainer_lib.accumulate_grads_microbatched( - self.test_trainer._model, - self.init_train_state, - batch, - self.test_trainer._base_rng, - num_micro_batches, - ) - ) - - expected_grad_accum = {'bias': jnp.ones(4), 'kernel': jnp.ones((2, 4))} - chex.assert_trees_all_equal(expected_grad_accum, grad_accum) - self.assertEqual(metrics['loss'].compute(), 2) - self.assertEqual(metrics['accuracy'].compute(), 2) - self.assertIsNone(flax_mutables) - - def test_eval_step_without_weight_sum(self): - batch_iter = self.dataset.as_numpy_iterator() - batch = next(batch_iter) - self.test_trainer._model.eval_fn = fake_eval_fn_without_weight_sum - metrics = trainer_lib.eval_step( - self.test_trainer._model, self.init_train_state, batch - ) - - self.assertEqual(metrics['loss'].compute(), 1) - self.assertEqual(metrics['accuracy'].compute(), 1) - - -class TrainerRngDeterminismTest(parameterized.TestCase): - - def create_trainer(self, step, random_seed): - init_optimizer = optimizers.Optimizer( - optimizers.sgd(0.1), - state=optimizers.OptimizerState( - step=step, param_states={'bias': 0, 'kernel': 0} - ), - target={'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}, - ) - init_train_state = train_state_lib.FlaxOptimTrainState(init_optimizer) - train_state_axes = jax.tree.map(lambda x: None, init_train_state) - - test_trainer = trainer_lib.Trainer( - mock.create_autospec(models_lib.BaseModel, instance=True), - init_train_state, - partitioning.PjitPartitioner(num_partitions=1), - eval_names=['task1', 'task2'], - summary_dir=None, - train_state_axes=train_state_axes, - rng=jax.random.PRNGKey(random_seed), - learning_rate_fn=lambda step: 2 * step, - num_microbatches=None, - ) - return test_trainer - - @mock.patch('t5x.trainer.accumulate_grads_microbatched') - @mock.patch('t5x.trainer.apply_grads', fake_apply_grads) - def test_rng_determinism(self, mock_accum_grads): - - def fake_accum_grads_rng( - model, optimizer, batch, rng, num_microbatches, data_partition_spec - ): - del model, batch, num_microbatches, data_partition_spec - # Add 1, which will increment the step as a side effect. - grad_accum = jax.tree.map(lambda x: 1, optimizer) - m = {'rng': metrics_lib.Sum(jnp.sum(jax.random.key_data(rng)))} - return grad_accum, m, None - - mock_accum_grads.side_effect = fake_accum_grads_rng - # Create a trainer at a given step (53) with a given random seed (23), - # train up to a given train step (100), check the sum of the rngs from the - # metrics. - start_step = 47 - end_step = 100 - random_seed = 23 - trainer = self.create_trainer(step=start_step, random_seed=random_seed) - # 500 batches of size 2 - ds = [np.zeros(2)] * 500 - - metrics = trainer.train(iter(ds), num_steps=end_step - start_step) - base_rng = jax.random.PRNGKey(random_seed) - expected_rng_sum = np.sum( - [ - jax.random.key_data(jax.random.fold_in(base_rng, i)) - for i in range(start_step, end_step) - ], - dtype=np.uint32, - ) - np.testing.assert_array_equal( - metrics.result()['rng'].value, expected_rng_sum - ) - - -def fake_mut_accum_grads( - model, optimizer, batch, rng, num_microbatches, data_partition_spec -): - del model, num_microbatches, rng, data_partition_spec - # Add `i` to each optimzer value. - i = batch['i'].sum() - grad_accum = jax.tree.map(lambda x: i, optimizer) - # Add j to each metric. - j = batch['j'].sum() - metrics = { - 'loss': metrics_lib.Sum.from_model_output(j), - 'accuracy': metrics_lib.Sum.from_model_output(j), - } - return grad_accum, metrics, {'mutables': 0} - - -def fake_mut_apply_grads( - optimizer, - grad_accum, - metrics, - learning_rate, - weight_metrics_computer, - other_state_variables, -): - del weight_metrics_computer, other_state_variables - metrics['learning_rate'] = clu.metrics.Average.from_model_output( - learning_rate - ) - optimizer = jax.tree.map(lambda x, g: x + g, optimizer, grad_accum) - return optimizer, metrics - - -class MutableTrainerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.init_optimizer = optimizers.Optimizer( - optimizers.sgd(0.1), - state=optimizers.OptimizerState( - step=0, param_states={'bias': 0, 'kernel': 0} - ), - target={'bias': np.zeros(4), 'kernel': np.zeros((2, 4))}, - ) - self.init_train_state = train_state_lib.FlaxOptimTrainState( - _optimizer=self.init_optimizer, - flax_mutables=FlaxMutables( - variables={ - 'keys': np.zeros((10, 2)), - 'values': np.zeros((10, 5)), - } - ), - ) - train_state_axes = jax.tree.map(lambda x: None, self.init_train_state) - model_dir = self.create_tempdir().full_path - - mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} - self.dataset = ( - tf.data.Dataset.range(6).map(mapfn).batch(2, drop_remainder=True) - ) - - self.test_trainer = trainer_lib.Trainer( - mock.create_autospec(models_lib.BaseModel, instance=True), - self.init_train_state, - partitioning.PjitPartitioner(num_partitions=1), - eval_names=['task1', 'task2'], - summary_dir=model_dir, - train_state_axes=train_state_axes, - rng=np.ones(2, np.uint32), - learning_rate_fn=lambda step: 2 * (step + 1), - num_microbatches=None, - ) - - @mock.patch('t5x.trainer._time') - @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_mut_accum_grads) - @mock.patch('t5x.trainer.apply_grads', fake_mut_apply_grads) - # avoids calls time.time() during logging - @mock.patch('absl.logging.info', lambda *_: None) - @mock.patch('absl.logging.log_every_n_seconds', lambda *_: None) - def test_train(self, mock_time=None): - trainer = self.test_trainer - initial_rng = trainer._base_rng - - trainer._partitioned_train_step = mock.Mock( - side_effect=trainer._partitioned_train_step - ) - - # train start, logging, train end, logging - mock_time.side_effect = [1, 5, 5, 5] - num_steps = 1 - ds_iter = self.dataset.as_numpy_iterator() - batch = next(ds_iter) - train_state, _ = trainer._partitioned_train_step(trainer.train_state, batch) - - expected_train_state = jax.tree.map( - lambda x: np.array(x + 1), self.init_train_state - ) - # Base rng must remain the same - np.testing.assert_array_equal(trainer._base_rng, initial_rng) - jax.tree.map(np.testing.assert_equal, train_state, expected_train_state) - - self.assertIsNone(trainer._compiled_train_step) - self.assertEqual(trainer._partitioned_train_step.call_count, num_steps) - - @mock.patch('jax.value_and_grad', fake_value_and_grad_fn_without_weight_sum) - def test_accumulate_grads_microbatched_without_weight_sum_single_batch(self): - batch_iter = self.dataset.as_numpy_iterator() - batch = next(batch_iter) - num_microbatches = 1 - grad_accum, metrics, flax_mutables = ( - trainer_lib.accumulate_grads_microbatched( - self.test_trainer._model, - self.init_train_state, - batch, - self.test_trainer._base_rng, - num_microbatches, - ) - ) - - i = batch['i'].sum() - expected_grad_accum = jax.tree.map( - lambda x: i, self.init_train_state - ).params - self.assertEqual(expected_grad_accum, grad_accum) - self.assertEqual(metrics['loss'].compute(), 2) - self.assertEqual(metrics['accuracy'].compute(), 2) - self.assertIsNotNone(flax_mutables) - - @mock.patch( - 'jax.value_and_grad', fake_value_and_grad_fn_wo_weight_sum_w_mutables - ) - def test_accumulate_grads_microbatched_without_weight_sum_multiple_batches( - self, - ): - batch_iter = self.dataset.as_numpy_iterator() - batch = next(batch_iter) - num_micro_batches = 2 - grad_accum, metrics, flax_mutables = ( - trainer_lib.accumulate_grads_microbatched( - self.test_trainer._model, - self.init_train_state, - batch, - self.test_trainer._base_rng, - num_micro_batches, - ) - ) - - expected_grad_accum = {'bias': jnp.ones(4), 'kernel': jnp.ones((2, 4))} - chex.assert_trees_all_equal(expected_grad_accum, grad_accum) - self.assertEqual(metrics['loss'].compute(), 2) - self.assertEqual(metrics['accuracy'].compute(), 2) - self.assertIsNotNone(flax_mutables) - - def test_eval_step(self): - batch_iter = self.dataset.as_numpy_iterator() - batch = next(batch_iter) - self.test_trainer._model.eval_fn = fake_eval_fn_with_mutables - metrics = trainer_lib.eval_step( - self.test_trainer._model, self.init_train_state, batch - ) - - self.assertEqual(metrics['loss'].compute(), 1) - self.assertEqual(metrics['accuracy'].compute(), 1) - - def tearDown(self) -> None: - # Manually close managers to avoid phantom threads crossing test cases. - self.test_trainer.train_metrics_manager.close() - for mm in self.test_trainer.eval_metrics_managers.values(): - mm.close() - return super().tearDown() - - -if __name__ == '__main__': - absltest.main() diff --git a/t5x-main/t5x/utils.py b/t5x-main/t5x/utils.py deleted file mode 100644 index 15c872de7c8cb9d4057704ef522bcbad794c582a..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/utils.py +++ /dev/null @@ -1,2388 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""General utility functions for t5x.""" - -import collections -import collections.abc -from concurrent.futures import thread -import contextlib -import dataclasses -import functools -import importlib -import inspect -import os -import re -import time -import typing -from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union -import warnings - -from absl import app # pylint: disable=unused-import -from absl import flags -from absl import logging -import airio.core as airio -import clu.data -import flax -from flax import traverse_util -import flax.core -from flax.core import scope as flax_scope -from flax.linen import partitioning as flax_partitioning -import jax -from jax.experimental import multihost_utils -import jax.numpy as jnp -import numpy as np -import orbax.checkpoint -import seqio -from t5x import checkpoints -from t5x import models -from t5x import optimizers -from t5x import partitioning -from t5x import state_utils -from t5x import train_state as train_state_lib -import tensorflow as tf -from tensorflow.io import gfile -import typing_extensions - - -FLAGS = flags.FLAGS - -# Remove _ShardedDeviceArray when users of t5x have their types updated -_ShardedDeviceArray = Any -Array = Union[np.ndarray, jnp.ndarray, _ShardedDeviceArray, tf.Tensor] -PyTree = Any -PartitionSpec = partitioning.PartitionSpec -DType = Union[np.dtype, type(jnp.bfloat16)] -Shape = Tuple[int, ...] - -# TODO(adarob): Remove namespace mapping after client gin files are updated. -TensorBoardLogger = seqio.TensorBoardLogger -get_local_data = checkpoints.get_local_data - - -class EvaluatorConstructor(typing_extensions.Protocol): - """A function that returns an Evaluator. - - This protocol represents the actual callsite for the seqio.Evaluator c'tor - in this file. It allows users to bind additional args with partial() and - pass that partial into the fn without causing type check issues. - """ - - def __call__( - self, - mixture_or_task_name: str, - feature_converter: seqio.FeatureConverter, - eval_split: str, - use_cached: bool, - seed: Optional[int], - sequence_length: Optional[Mapping[str, int]], - log_dir: Optional[str], - use_memory_cache: bool, - ) -> seqio.Evaluator: - """The call for the seqio.Evaluator c'tor in this file. - - Args: - mixture_or_task_name: a registered task or mixture name. - feature_converter: a feature converter object to use to convert the task - features to model features. Must be a subclass of - seqio.FeatureConverter. - eval_split: evaluation split. Typically "validation" or "test". - use_cached: whether to use the cached dataset instead of processing it on - the fly. - seed: random seed used for dataset shuffle and preprocessing. This is - usually not needed since eval datasets aren't shuffled and shouldn't use - stochastic operations. It is only useful for in certain data sources - such as `FewshotDataSource` where the training examples are randomly - selected during evaluation. - sequence_length: an optional length specification. If specified, these - will be the hard-limit on the evaluation data used for prediction. If - none of the preprocessors depend on the sequence length, it can be left - unspecified and the maximum length for each feature will be used. These - lengths are computed while caching the datasets. - log_dir: the directory to log outputs to. Required if `logger_cls` is - non-empty. - use_memory_cache: whether to use tf.data.Dataset#cache. May cause memory - issues for large datasets. - - Returns: - A seqio.Evaluator. - """ - ... - - -# ----------------------------------------------------------------------------- -# Configurations -# ----------------------------------------------------------------------------- -@dataclasses.dataclass -class SaveCheckpointConfig: - """Configuration for saving model checkpoints.""" - - # The dtype to save ('float32' or 'bfloat16'). - dtype: str = 'float32' - # Number of steps between writing checkpoints. - period: Optional[int] = None - # List of training steps (inputted as integers) to save checkpoints for - checkpoint_steps: Optional[Sequence[int]] = None - # Number of most recent checkpoints to keep, or None to keep them all. - keep: Optional[int] = None - # Number of dataset checkpoints to keep, or None to keep them all. - # Note: Dataset checkpoints are also affected by `keep`. - keep_dataset_checkpoints: Optional[int] = None - # Whether to save dataset checkpoints. - save_dataset: bool = False - # The checkpointer class to use. - checkpointer_cls: checkpoints.CheckpointerConstructor = ( - checkpoints.Checkpointer - ) - # Transformations to apply, in order, to the state before writing. - state_transformation_fns: Sequence[checkpoints.SaveStateTransformationFn] = ( - dataclasses.field(default_factory=list) - ) - # CheckpointManager implementation to use. - checkpoint_manager_cls: checkpoints.CheckpointManagerConstructor = ( - checkpoints.OrbaxCheckpointManagerInterface - ) - - def __post_init__(self): - if self.dtype not in (None, 'float32', 'bfloat16'): - raise ValueError( - "`SaveCheckpointConfig.dtype` must be one of None, 'float32' or " - f"'bfloat16'. Got {self.dtype}." - ) - - -@dataclasses.dataclass -class RestoreCheckpointConfig: - """Configuration for restoring model from checkpoint.""" - - # Path(s) to checkpoint to restore from or directory (depending on `mode`). - path: Union[str, Sequence[str]] - # One of 'specific', 'latest', or 'all'. - # specific: load the checkpoint specified by `path`. - # latest: load most recent checkpoint in the directory specified by `path`. - # all: sequentially load all of checkpoints in the directory `path`. - mode: str = 'latest' - # An optional sequence of (pattern, replacement) regex pairs. The pattern - # matches parameters in the model and the replacement matches the checkpoint - # (after substitutions). The replacement may be None, in which case the - # parameter can be dropped. Use `fallback_to_scratch` to fill them in with - # newly initialized values. - assignment_map: Optional[Sequence[Tuple[str, Optional[str]]]] = None - # Whether to restore all optimizer parameters from the checkpoint. - strict: bool = True - # Whether to initialize parameters that are in the model being restored but - # are missing from the checkpoint (after `assignment_map` is applied). - fallback_to_scratch: bool = False - # The dtype to restore ('float32' or 'bfloat16'), or None to load as saved. - dtype: Optional[str] = None - # Whether to restore the dataset checkpoint. Fails if checkpoint not present. - restore_dataset: bool = False - # The checkpointer class to use. - checkpointer_cls: checkpoints.CheckpointerConstructor = ( - checkpoints.Checkpointer - ) - # Transformations to apply, in order, to the state after reading. These will - # be applied after the `assignment_map` transformations. - state_transformation_fns: Sequence[ - checkpoints.RestoreStateTransformationFn - ] = () - # CheckpointManager implementation to use. - checkpoint_manager_cls: checkpoints.CheckpointManagerConstructor = ( - checkpoints.OrbaxCheckpointManagerInterface - ) - - def __post_init__(self): - if self.mode not in ('specific', 'latest', 'all'): - raise ValueError( - "`RestoreCheckpointConfig.mode` must be one of 'specific', 'latest', " - f"or 'all'. Got {self.mode}." - ) - if self.dtype not in (None, 'float32', 'bfloat16', 'float16'): - raise ValueError( - "`RestoreCheckpointConfig.dtype` must be one of `None`, 'float32', " - f"'float16' or 'bfloat16'. Got {self.dtype}." - ) - if self.assignment_map is not None: - # Turns `assignment_map` into a transformation function. - assignment_map_fn = functools.partial( - state_utils.apply_assignment_map, assignment_map=self.assignment_map - ) - # Prepends the `assignment_map` transformation to the front of the list. - self.state_transformation_fns = ( - assignment_map_fn, - *self.state_transformation_fns, - ) - - -@dataclasses.dataclass -class CheckpointConfig: - """Configuration for checkpointing of model and dataset.""" - - save: Optional[SaveCheckpointConfig] = None - restore: Optional[RestoreCheckpointConfig] = None - - -class LegacyCheckpointer(orbax.checkpoint.Checkpointer): - """Implementation of Checkpointer interface for T5X. - - Relies on underlying save_checkpointer and restore_checkpointer, which are - t5x.checkpoints.Checkpointer objects. - """ - - def __init__( - self, - *, - save_checkpointer: Optional[checkpoints.Checkpointer] = None, - restore_checkpointer: checkpoints.Checkpointer, - strict: Optional[bool] = False, - ): - self._save_checkpointer = save_checkpointer - self._restore_checkpointer = restore_checkpointer - self._strict = strict - - def save( - self, - path: str, - item: train_state_lib.TrainState, - force: bool = False, - custom_metadata: dict[str, Any] | None = None, - state_transformation_fns: Sequence[ - checkpoints.SaveStateTransformationFn - ] = (), - *, - concurrent_gb: int = 128, - ): - """Performs save operation using save_checkpointer. - - Args: - path: path to save item to. - item: a TrainState PyTree to save. - force: unused. - custom_metadata: unused. - state_transformation_fns: Transformations to apply, in order, to the state - before writing. - concurrent_gb: the approximate number of gigabytes of partitionable - parameters to process in parallel. Useful to preserve RAM. - """ - train_state = item - del path # stored in save_checkpointer - # dataset_iterator is also saved, but is provided in checkpointer init - if self._save_checkpointer is None: - raise ValueError( - "`_save_checkpointer` is not set up. Can't save checkpoints." - ) - self._save_checkpointer.save( - train_state, state_transformation_fns, concurrent_gb=concurrent_gb - ) - - def restore( - self, - path: str, - item: Optional[train_state_lib.TrainState] = None, - state_transformation_fns: Sequence[ - checkpoints.RestoreStateTransformationFn - ] = (), - fallback_state: Optional[Mapping[str, Any]] = None, - lazy_parameters: bool = False, - ) -> train_state_lib.TrainState: - """Performs restore operation using restore_checkpointer. - - Determines whether the indicated path is a Tensorflow checkpoint. - - Args: - path: the string path to restore from. - item: a TrainState PyTree to restore. Unused. - state_transformation_fns: Transformations to apply, in order, to the state - before writing. - fallback_state: a state dict of an optimizer to fall back to for loading - params that do not exist in the checkpoint (after applying all - `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. - The union of `fallback_state` and state loaded from the checkpoint must - match `Checkpointer.optimizer`. - lazy_parameters: whether to load the parameters as LazyArrays to preserve - memory. - - Returns: - The restored train state. - """ - del item # not needed for restore in T5X - from_tensorflow = gfile.exists(path + '.index') - if from_tensorflow and state_transformation_fns: - raise ValueError( - 'Cannot initialize from a TensorFlow checkpoint using ' - '`state_transformation_fns`.' - ) - if from_tensorflow: - logging.info( - 'Initializing parameters from TensorFlow checkpoint %s', path - ) - return self._restore_checkpointer.restore_from_tf_checkpoint( - path, strict=self._strict - ) - return self._restore_checkpointer.restore( - path=path, - state_transformation_fns=state_transformation_fns, - fallback_state=fallback_state, - lazy_parameters=lazy_parameters, - ) - - -class LegacyCheckpointManager(orbax.checkpoint.CheckpointManager): - """Implementation of CheckpointManager interface for T5X. - - Uses underlying LegacyCheckpointer to handle save/restore for Dataset and - TrainState. - """ - - def __init__( - self, - *, - save_cfg: Optional[SaveCheckpointConfig], - restore_cfg: Optional[RestoreCheckpointConfig], - train_state_shape: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - ds_iter: Optional[ - Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] - ] = None, - model_dir: Optional[str] = None, - ): - if save_cfg is not None: - if save_cfg.save_dataset: - assert ds_iter is not None - save_checkpointer = save_cfg.checkpointer_cls( # pytype: disable=wrong-arg-types # jnp-type - train_state=train_state_shape, - partitioner=partitioner, - checkpoints_dir=model_dir, - dataset_iterator=ds_iter if save_cfg.save_dataset else None, - save_dtype=save_cfg.dtype, - keep=save_cfg.keep, - keep_dataset_checkpoints=save_cfg.keep_dataset_checkpoints, - ) - else: - save_checkpointer = None - - if restore_cfg: - restore_checkpointer = restore_cfg.checkpointer_cls( - train_state=train_state_shape, - partitioner=partitioner, - checkpoints_dir='', # unused for restore - dataset_iterator=ds_iter if restore_cfg.restore_dataset else None, - restore_dtype=jnp.dtype(restore_cfg.dtype) - if restore_cfg.dtype - else None, - ) - strict = restore_cfg.strict - else: - restore_checkpointer = None - strict = False - - self._checkpointer = LegacyCheckpointer( - save_checkpointer=save_checkpointer, - restore_checkpointer=restore_checkpointer, - strict=strict, - ) - - def wait_until_finished(self): - pass - - def close(self): - pass - - def save( - self, - train_state: train_state_lib.TrainState, - state_transformation_fns: Sequence[ - checkpoints.SaveStateTransformationFn - ] = (), - ): # pytype: disable=signature-mismatch - """Performs save operation. - - Args: - train_state: a TrainState PyTree to save. - state_transformation_fns: Transformations to apply, in order, to the state - before writing. - """ - self._checkpointer.save( - path='', # not used - item=train_state, - state_transformation_fns=state_transformation_fns, - ) - - def restore( - self, - paths: Sequence[str], - restore_cfg: Optional[RestoreCheckpointConfig] = None, - fallback_state: Optional[Mapping[str, Any]] = None, - ) -> Optional[ - Union[train_state_lib.TrainState, Sequence[train_state_lib.TrainState]] - ]: # pytype: disable=signature-mismatch - """Performs restore operation using restore_checkpointer. - - Determines whether the indicated path is a Tensorflow checkpoint. - - Args: - paths: A sequence of paths to restore from. - restore_cfg: RestoreCheckpointConfig specifying restoration information. - fallback_state: a state dict of an optimizer to fall back to for loading - params that do not exist in the checkpoint (after applying all - `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. - The union of `fallback_state` and state loaded from the checkpoint must - match `Checkpointer.optimizer`. - - Returns: - The restored TrainState if only one TrainState can be restored from the - given paths, otherwise a sequence of TrainStates. May return None. - """ - if restore_cfg is None or paths is None: - return None - - restored = [] - for path in paths: - logging.info( - 'Initializing parameters from specific T5X checkpoint %s', path - ) - restored.append( - self._checkpointer.restore( - path=path, - item=None, # not used - state_transformation_fns=restore_cfg.state_transformation_fns, - fallback_state=fallback_state, - ) - ) - - if len(restored) == 1: - restored = restored[0] - return restored - - -def restore( - checkpoint_manager: checkpoints.OrbaxCheckpointManagerInterface, - paths: Sequence[str], - restore_cfg: RestoreCheckpointConfig, - fallback_state: Optional[Mapping[str, Any]] = None, -) -> Union[train_state_lib.TrainState, Sequence[train_state_lib.TrainState]]: - """Performs restore operation using restore_checkpointer. - - Determines whether the indicated path is a Tensorflow checkpoint. - - Args: - checkpoint_manager: OrbaxCheckpointManagerInterface - paths: A sequence of paths to restore from. - restore_cfg: RestoreCheckpointConfig specifying restoration information. - fallback_state: a state dict of an optimizer to fall back to for loading - params that do not exist in the checkpoint (after applying all - `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. The - union of `fallback_state` and state loaded from the checkpoint must match - `Checkpointer.optimizer`. - - Returns: - The restored TrainState if only one TrainState can be restored from the - given paths, otherwise a sequence of TrainStates. - """ - if restore_cfg is None or paths is None: - return None - - state_transformation_fns = restore_cfg.state_transformation_fns - restored_checkpoints = [] - for path in paths: - logging.info( - 'Initializing parameters from specific T5X checkpoint %s', path - ) - - from_tensorflow = gfile.exists(path + '.index') - if from_tensorflow and state_transformation_fns: - raise ValueError( - 'Cannot initialize from a TensorFlow checkpoint using ' - '`state_transformation_fns`.' - ) - if from_tensorflow: - logging.info( - 'Initializing parameters from TensorFlow checkpoint %s', path - ) - return checkpoint_manager.restore_from_tf_checkpoint( - path, strict=restore_cfg.strict - ) - - restored = checkpoint_manager.restore( - path=path, - state_transformation_fns=state_transformation_fns, - fallback_state=fallback_state, - ) - restored_checkpoints.append(restored) - - if len(restored_checkpoints) == 1: - restored_checkpoints = restored_checkpoints[0] - return restored_checkpoints - - -@dataclasses.dataclass -class DatasetConfig: - """Configuration for loading a dataset from a SeqIO Task or Mixture.""" - - mixture_or_task_name: Union[ - str, seqio.Task, seqio.Mixture, airio.Task, airio.Mixture - ] - task_feature_lengths: Mapping[str, int] - split: str - batch_size: int # Number of examples per batch. - shuffle: bool - seed: Optional[int] - # Whether to use a precomputed version of the dataset from a cache dir. - use_cached: bool = False - pack: bool = False - # Whether to use tensor2tensor custom ops for more efficient packing. - use_custom_packing_ops: bool = False - # An optional module to import for registering the referenced Mixture or Task. - # DEPRECATED. - module: Optional[str] = None - # Whether to cache the dataset in memory (only applies to evaluation data). - use_memory_cache: bool = True - # Whether to trim output features from tasks. - trim_output_features: bool = True - ### AirIO-only ### - # A list of runtime preprocessors to pass to airio. Generally used - # to configure feature converters and packing. Ignored for non-airio configs. - runtime_preprocessors: Sequence[Any] | None = None - # The number of threads reading from the data source in parallel. Passing None - # or 0 will use the default number of threads. - num_prefetch_threads: int | None = None - # Number of Python worker processes. More processes can speed up - # the pipeline if it's compute bound and bottlenecked on the CPython's GIL. - # 0 means no Python multiprocessing. All data loading and transformation - # will run in the main Python process. - num_workers: int | None = 0 - - -def _hashed_index(x) -> int: - # This works for both `pjit`/`xmap` indices and `pmap` indices (which might - # have an integer instead of a slice). - assert all(v.step is None for v in x if isinstance(v, slice)) - return hash( - tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x) - ) - - -def _get_index_mappings(device_to_idxs): - """Get device and host to index set mappings for GDA construction.""" - host_to_idxs = collections.defaultdict(list) - idx_to_devices = collections.defaultdict(list) - for d, idx in device_to_idxs.items(): - hashed_idx = _hashed_index(idx) - # Only need one copy of each idx, since they are unique. Need to maintain - # original ordering though. - if hashed_idx not in host_to_idxs[d.process_index]: - host_to_idxs[d.process_index].append(hashed_idx) - # Index may correspond to multiple devices. - idx_to_devices[hashed_idx].append(d) - - assert jax.process_index() in host_to_idxs - for h1, idxs1 in host_to_idxs.items(): - for idx in idxs1: - assert idx in idx_to_devices - for h2, idxs2 in host_to_idxs.items(): - if h1 == h2: - continue - assert not (set(idxs1) & set(idxs2)) or set(idxs1) == set(idxs2) - - return host_to_idxs, idx_to_devices - - -def _create_sharded_array( - partitioner: partitioning.BasePartitioner, - global_shapes: PyTree, - host_arrays: PyTree, -) -> PyTree: - """Create jax.Array from input arrays. - - Example: - - Consider a case where the global input array has length 128. The global mesh - specifies that the data dimension be sharded into 8 shards. This means we want - shards of length 16. The data_layout, defined by the partitioner object, - specifies that the data should be divided into two shards, one per host. Each - host will have a local slice of the data (length 64). - - In this function, we will divide the local array into 4 shards of length 16. - Each of these will be placed onto a separate device. If the sharding had - specified only 4 global shards instead of 8, we would have divided our local - array into only 2 shards. In this case, the first shard would be placed on the - first two devices (replicated) and the second on the following two devices. - - Args: - partitioner: Partitioner object containing mesh and mesh_axes - global_shapes: PyTree matching host_arrays specifying global shape of each - array. - host_arrays: PyTree of LOCAL arrays (not global) that should be converted to - jax.Array. - - Returns: - PyTree matching host_arrays of jax.Array. - """ - global_mesh = partitioner.mesh - axes = partitioner.data_partition_spec - local_devices = global_mesh.local_devices - local_device_count = jax.local_device_count() - - # Global input array is already split into per-host shards. - def _put_to_devices(x, global_shape): - # Mapping of device to index slice from *global* array. - - device_to_idxs = jax.sharding.NamedSharding( - global_mesh, axes - ).devices_indices_map(global_shape) - # Mapping of host to a set of unique index slices for that host. - # Mapping of index slice to a list of devices onto which the slice should be - # placed. - host_to_idxs, idx_to_devices = _get_index_mappings(device_to_idxs) - - shard_length = jax.sharding.NamedSharding(global_mesh, axes).shard_shape( - global_shape - )[0] - num_shards = len(x) // shard_length - try: - local_array_shards = np.split(x, num_shards, axis=0) - except ValueError as array_split_error: - raise ValueError( - f'Unable to put to devices shape {x.shape} with ' - f'local device count {local_device_count}' - ) from array_split_error - - # Construct mapping of device to index in the split local array. - device_to_split_array_idx = {} - i = 0 - for idx in host_to_idxs[jax.process_index()]: - assert idx in idx_to_devices - for d in idx_to_devices[idx]: - device_to_split_array_idx[d] = i % len(local_array_shards) - i += 1 - - device_buffers = [] - for d in local_devices: - assert d in device_to_split_array_idx - i = device_to_split_array_idx[d] - device_buffers.append(jax.device_put(local_array_shards[i], d)) - - return device_buffers - - device_buffers = jax.tree.map(_put_to_devices, host_arrays, global_shapes) - - def _jax_array(dbs, global_shape): - return jax.make_array_from_single_device_arrays( - global_shape, jax.sharding.NamedSharding(global_mesh, axes), dbs - ) - - return jax.tree.map( - _jax_array, - device_buffers, - global_shapes, - is_leaf=lambda x: isinstance(x, (list, tuple)), - ) - - -class ShardedDatasetIterator(clu.data.dataset_iterator.DatasetIterator): - """A wrapper iterator that returns sharded arrays.""" - - def __init__( - self, - iterator: clu.data.dataset_iterator.DatasetIterator, - partitioner: partitioning.BasePartitioner, - global_shapes: PyTree, - ): - self._iterator = iterator - self._global_shapes = global_shapes - self._partitioner = partitioner - - def __next__(self): - return _create_sharded_array( - self._partitioner, self._global_shapes, next(self._iterator) - ) - - def reset(self): - return self._iterator.reset() - - @property - def element_spec(self): - return self._iterator.element_spec - - def save(self, filename): - return self._iterator.save(filename) - - def restore(self, filename): - return self._iterator.restore(filename) - - @property - def iterator(self): - if isinstance(self._iterator, clu.data.dataset_iterator.TfDatasetIterator): - return self._iterator.iterator - return self._iterator - - -def prepare_train_iter( - train_iter: Union[ - tf.data.Dataset, clu.data.dataset_iterator.DatasetIterator - ], - *, - partitioner, - checkpoint_cfg, - data_layout, -) -> clu.data.dataset_iterator.PeekableDatasetIterator: - """Prepares the training input iterator.""" - if isinstance(train_iter, airio.AirIODatasetIterator): - return train_iter - if isinstance(train_iter, tf.data.Dataset): - train_iter = clu.data.dataset_iterator.TfDatasetIterator( - train_iter, checkpoint=True - ) - elif not isinstance(train_iter, clu.data.dataset_iterator.DatasetIterator): - raise ValueError( - f'get_dataset_fn returned unsupported type {type(train_iter)}.' - ) - - - input_shapes = jax.tree.map( - lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec - ) - train_iter = ShardedDatasetIterator(train_iter, partitioner, input_shapes) - return clu.data.dataset_iterator.PeekableDatasetIterator(train_iter) - - -def sync_global_devices(name: str) -> None: - """Creates a barrier with given name across all hosts/devices.""" - # Internal mock TPU handling - multihost_utils.sync_global_devices(name) - - -def multihost_assert_equal(input_tree, fail_message: str = ''): - """Verifies that all the hosts have the same tree of values.""" - # Internal mock TPU handling - multihost_utils.assert_equal(input_tree, fail_message) - - -# ------------------------------------------------------------------------------ -# Fast *nondeterministic* hardware RNG for faster Dropout -# ------------------------------------------------------------------------------ -def _hardware_uniform( # pytype: disable=annotation-type-mismatch # jnp-type - rng_key: Array, - shape: Shape, - dtype: jnp.dtype = np.float32, - minval: Array = np.float32(0), - maxval: Array = np.float32(1), -) -> Array: - """Random uniform method that uses non-deterministic accelerator hardware.""" - del rng_key # non-deterministic prng. - minval = jax.lax.convert_element_type(minval, dtype) - maxval = jax.lax.convert_element_type(maxval, dtype) - return jax.lax.rng_uniform(minval, maxval, shape) - - -# For dropout-only hardware rng. -def _hardware_bernoulli( - rng_key: Array, - p: Union[np.ndarray, np.floating] = np.float32(0.5), - shape: Shape = (), -) -> Array: - del rng_key # non-deterministic prng. - return jax.lax.rng_uniform(0.0, 1.0, shape) < p - - -def set_hardware_rng_ops(): - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') - - -# ----------------------------------------------------------------------------- -# Training utility functions. -# ----------------------------------------------------------------------------- - - -def get_zeros_batch_like_spec( - batch_spec: Mapping[str, jax.ShapeDtypeStruct] -) -> Mapping[str, jnp.ndarray]: - return {k: jnp.zeros(t.shape, t.dtype) for k, t in batch_spec.items()} - - -def get_zeros_batch_like_dataset( - dataset: Union[tf.data.Dataset, airio.AirIODatasetIterator], - batch_size=None, -) -> Mapping[str, jnp.ndarray]: - """Get zeros batch like the dataset spec.""" - reshape = lambda s: (batch_size,) + s[1:] if batch_size else tuple(s) - batch_spec = {} - for key, val in dataset.element_spec.items(): # pytype: disable=attribute-error - if isinstance(dataset, tf.data.Dataset): - static_attributes = jax.ShapeDtypeStruct( - reshape(val.shape), val.dtype.as_numpy_dtype - ) - else: - static_attributes = jax.ShapeDtypeStruct(val.shape, val.dtype) - batch_spec[key] = static_attributes - return get_zeros_batch_like_spec(batch_spec) - - -class InitFnCallable(typing_extensions.Protocol): - """A callable that initializes model variables.""" - - def __call__( - self, - rng: Array, - input_shapes: Mapping[str, Array], - input_types: Optional[Mapping[str, DType]], - ) -> flax_scope.FrozenVariableDict: - ... - - -class LearningRateCallable(typing_extensions.Protocol): - - def __call__(self, step: jnp.ndarray) -> jnp.ndarray: - ... - - -def create_learning_rate_scheduler( - factors: str = 'constant * linear_warmup * rsqrt_decay', - base_learning_rate: float = 0.5, - warmup_steps: int = 1000, - decay_factor: float = 0.5, - steps_per_decay: int = 20000, - steps_per_cycle: int = 100000, - step_offset: int = 0, - min_learning_rate: float = 1e-8, -) -> LearningRateCallable: - """Creates learning rate schedule. - - Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * linear_decay: linear decay from warmup_steps with decay_factor slope. Note - this option implies 'constant * linear_warmup', and should not be used in - in conjunction with `constant` or `linear_warmup` factors. - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - - Args: - factors: string, factors separated by '*' that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - step_offset: int, an offset that the step parameters to this function are - relative to. - min_learning_rate: float, minimum learning rate to output. Useful for cases - when a decay function is (mis)configured to decay to non-positive values. - - Returns: - a function learning_rate(step): float -> {'learning_rate': float}, the - step-dependent lr. - """ - factors = [n.strip() for n in factors.split('*')] - - def step_fn(step: jnp.ndarray) -> jnp.ndarray: - """Step to learning rate function.""" - step = jnp.maximum(0, step - step_offset) - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= base_learning_rate - elif name == 'linear_warmup': - ret *= jnp.minimum(1.0, step / warmup_steps) - elif name == 'linear_decay': - ret *= base_learning_rate * jnp.minimum( - step / warmup_steps, 1.0 + decay_factor * (warmup_steps - step) - ) - elif name == 'rsqrt_decay': - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= jnp.sqrt(warmup_steps) - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= decay_factor ** (step // steps_per_decay) - elif name == 'cosine_decay': - progress = jnp.maximum( - 0.0, (step - warmup_steps) / float(steps_per_cycle) - ) - ret *= jnp.maximum( - 0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))) - ) - else: - raise ValueError('Unknown factor %s.' % name) - ret = jnp.maximum(ret, min_learning_rate) - return jnp.asarray(ret, dtype=jnp.float32) - - return step_fn - - -def get_first_valid_restore_config_and_paths( - restore_cfgs: Sequence[RestoreCheckpointConfig], -) -> Tuple[Optional[RestoreCheckpointConfig], Sequence[str]]: - """Returns first valid restore_cfg and the paths to restore. - - Args: - restore_cfgs: a sequence of `RestoreCheckpointConfig` objects, which should - be filtered to determine the first valid object. - - Returns: - Tuple of valid RestoreCheckpointConfig and a sequence of paths. - If the first config encountered has mode 'specific', it is immediately - returned, along with its specified paths. - If the mode is 'all' or 'latest', checks to ensure that there are valid - checkpoints at each of the provided paths and filters the returned paths - accordingly. - """ - for restore_cfg in restore_cfgs: - paths = ( - [restore_cfg.path] - if isinstance(restore_cfg.path, str) - else restore_cfg.path - ) - if restore_cfg.mode == 'specific': - return restore_cfg, paths - elif restore_cfg.mode in ('all', 'latest'): - for ckpt_dir in paths: - if not gfile.isdir(ckpt_dir): - raise ValueError( - 'Checkpoint path(s) must be valid directories when using ' - "restore mode 'all' or 'latest'." - ) - # Check if this is a TensorFlow checkpoint dir. - tf_ckpt_state = tf.train.get_checkpoint_state(ckpt_dir) - - if tf_ckpt_state: - ckpt_paths = tf_ckpt_state.all_model_checkpoint_paths - else: - ckpt_paths = [ - os.path.join(ckpt_dir, f'checkpoint_{step}') - for step in checkpoints.all_steps(ckpt_dir) - ] - if not ckpt_paths: - logging.info( - 'No checkpoints found in specified directory: %s', ckpt_dir - ) - continue - if restore_cfg.mode == 'latest': - logging.info('Using latest T5X checkpoint.') - ckpt_paths = ckpt_paths[-1:] - return restore_cfg, ckpt_paths - else: - logging.error('Unsupported checkpoint restore mode: %s', restore_cfg.mode) - return None, [] - - -def get_fallback_state( - restore_cfg: RestoreCheckpointConfig, - init_fn: Callable[[jnp.ndarray], Mapping[str, Any]], - init_rng: Optional[jnp.ndarray], -) -> Optional[Mapping[str, Any]]: - """Returns the fallback_state that can be used in restore().""" - if restore_cfg is None: - return - if restore_cfg.fallback_to_scratch: - if not restore_cfg.state_transformation_fns: - raise ValueError( - '`state_transformation_fns` must be provided with ' - '`fallback_to_scratch`' - ) - if init_rng is None: - raise ValueError( - 'An `init_rng` must be provided with `fallback_to_scratch`' - ) - fallback_state = init_fn(init_rng) - else: - fallback_state = None - return fallback_state - - -class TrainStateInitializer: - """Helper for initializing partitioned TrainState from checkpoints or scratch. - - Common use cases: - - * To restore from a single checkpoint, use `from_checkpoint`. - * To iterate over multiple checkpoints without recompiling the model, - use `from_checkpoints`. - * To initialize from scratch, use `from_scratch`. - * To restore from a checkpoint with a fallback to initializing from scratch, - use `from_checkpoint_or_scratch`. - - Attributes: - global_train_state_shape: a TrainState containing the global (unpartitioned) - shape (in `jax.ShapeDtypeStruct`) of each parameter instead of its value. - train_state_axes: a TrainState object containing a PartitionSpec (or None) - for each parameter, in place of the parameter itself. - """ - - # TODO(adarob): Replace input_shapes and input_types with sample batch. - def __init__( - self, - optimizer_def: Optional[optimizers.OptimizerDefType], - init_fn: InitFnCallable, - input_shapes: Mapping[str, Array], - partitioner: partitioning.BasePartitioner, - input_types: Optional[Mapping[str, DType]] = None, - ): - """TrainStateInitializer constructor. - - Args: - optimizer_def: Optimizer def to be initialized, or None to create a - `InferenceState` without an optimizer. - init_fn: callable that initializes model variables from a PRNGKey and the - input shapes. - input_shapes: a mapping from key to array shape for each feature in the - global (unsharded) input batch. - partitioner: the partitioner to use. - input_types: a mapping from key to array type for each feature in the - global (unshared) input batch. If not provided, the type is assumed to - be `jnp.float32`. - """ - - def initialize_train_state(rng: Array): - initial_variables = flax.core.freeze( - init_fn(rng=rng, input_shapes=input_shapes, input_types=input_types) - ) - if optimizer_def: - return train_state_lib.FlaxOptimTrainState.create( - optimizer_def, initial_variables - ) - return train_state_lib.InferenceState.create(initial_variables) - - self._partitioner = partitioner - self.global_train_state_shape = jax.eval_shape( - initialize_train_state, rng=jax.random.PRNGKey(0) - ) - self.train_state_axes = partitioner.get_mesh_axes( - self.global_train_state_shape - ) - self._initialize_train_state = initialize_train_state - - # Currently scanned layers require passing annotations through to the - # point of the scan transformation to resolve an XLA SPMD issue. - - # init_fn is always(?) equal to model.get_initial_variables, fetch the model - # instance from the bound method. - model = init_fn.__self__ # pytype: disable=attribute-error - if ( - hasattr(model, 'module') - and hasattr(model.module, 'scan_layers') - and model.module.scan_layers - ): - if hasattr(model.module, 'spmd_annotations'): - # update top-level module with spmd annotations. - model.module = model.module.clone( - parent=None, spmd_annotations=self.train_state_axes.params - ) - - def from_scratch(self, init_rng: Array) -> train_state_lib.TrainState: - """Initializes the partitioned Optimizer from scratch.""" - logging.info('Initializing parameters from scratch.') - - # If pretraining and no checkpoint imported, we jit the (sharded-) init - # function to minimize fragmentation. We use the same partition - # setup as the training step/loop to initialize everything "in-place" and - # avoid communication or OOM. - p_initialize_train_state_fn = self._partitioner.partition( - self._initialize_train_state, - in_axis_resources=None, - out_axis_resources=self.train_state_axes, - ) - return p_initialize_train_state_fn(init_rng) - - # TODO(b/216650048) deprecate this function and use orbax. - def from_checkpoints( - self, - restore_cfgs: Sequence[RestoreCheckpointConfig], - ds_iter: Optional[tf.data.Iterator] = None, - init_rng: Optional[jnp.ndarray] = None, - ) -> Iterable[Tuple[train_state_lib.TrainState, str]]: - """Yields 0 or more restored partitioned Optimizers, and maybe datasets. - - The manner in which parameters are initialized depends on `restore_cfgs` and - `restore_cfgs` is iterated over and the first config that matches one or - more existing checkpoints is used to generate restored optimizers from the - checkpoint(s). Any remaining configs are ignored. - - Args: - restore_cfgs: ordered sequence of configurations specifying checkpoint(s) - to restore from. The first config to match a checkpoint will be used. - ds_iter: a tf.data.Iterator for the input data, or None. If provided, the - referenced iterator's state may be silently restored (depending on the - config's `restore_dataset` value) along with the optimizer. - init_rng: for initializing parameters from scratch when they are not - available in the checkpoint and `fallback_to_scratch` is True - - Yields: - TrainState with initialized optimizer, with parameters copied to devices. - Path to restored checkpoint. - """ - - def _restore_path(path, cfg): - if cfg is None: - raise ValueError('Expected valid `RestoreCheckpointConfig`.') - restore_checkpointer = cfg.checkpointer_cls( - train_state=self.global_train_state_shape, - partitioner=self._partitioner, - checkpoints_dir='', # unused for restore - dataset_iterator=ds_iter if cfg.restore_dataset else None, - restore_dtype=jnp.dtype(cfg.dtype) if cfg.dtype else None, - ) - - from_tensorflow = gfile.exists(path + '.index') - if from_tensorflow and cfg.state_transformation_fns: - raise ValueError( - 'Cannot initialize from a TensorFlow checkpoint using ' - '`state_transformation_fns`.' - ) - if from_tensorflow: - logging.info( - 'Initializing parameters from TensorFlow checkpoint %s', path - ) - return restore_checkpointer.restore_from_tf_checkpoint( - path, strict=cfg.strict - ) - - else: - fallback_state = get_fallback_state( - cfg, lambda rng: self.from_scratch(rng).state_dict(), init_rng - ) - - logging.info( - 'Initializing parameters from specific T5X checkpoint %s', path - ) - return restore_checkpointer.restore( - path=path, - state_transformation_fns=cfg.state_transformation_fns, - fallback_state=fallback_state, - ) - - restore_cfg, paths = get_first_valid_restore_config_and_paths(restore_cfgs) - for path in paths: - yield _restore_path(path, restore_cfg), path - - def from_checkpoint( - self, - ckpt_cfgs: Sequence[RestoreCheckpointConfig], - *, - ds_iter: Optional[tf.data.Iterator] = None, - init_rng: Optional[jnp.ndarray] = None, - ) -> Optional[train_state_lib.TrainState]: - """Restores (at most) 1 checkpoint using `from_checkpoints`, or dies.""" - train_states = [ - state - for state, _ in self.from_checkpoints( - ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng - ) - ] - if len(train_states) > 1: - raise ValueError( - f'Expected at most 1 checkpoint but got {len(train_states)} for ' - f'config(s): {ckpt_cfgs}' - ) - return (train_states[0]) if train_states else None - - def from_checkpoint_or_scratch( - self, - ckpt_cfgs: Sequence[RestoreCheckpointConfig], - *, - init_rng: Array, - ds_iter: Optional[tf.data.Iterator] = None, - ) -> Optional[train_state_lib.TrainState]: - """Initializes from checkpoint, if found, or from scratch.""" - return self.from_checkpoint( - ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng - ) or self.from_scratch(init_rng) - - -def create_checkpoint_manager_and_restore( - train_state_initializer: TrainStateInitializer, - partitioner: partitioning.BasePartitioner, - restore_checkpoint_cfg: RestoreCheckpointConfig, - restore_path: Optional[str], - fallback_init_rng: Optional[jnp.ndarray], - save_checkpoint_cfg: Optional[SaveCheckpointConfig] = None, - model_dir: Optional[str] = None, - ds_iter: Optional[ - Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] - ] = None, - use_orbax: bool = False, -) -> Tuple[ - Optional[train_state_lib.TrainState], - Union[LegacyCheckpointManager, checkpoints.OrbaxCheckpointManagerInterface], -]: - """Creates a CheckpointManager and restores TrainState if available.""" - - def _init(rng): - return train_state_initializer.from_scratch(rng).state_dict() - - restore_path_list = [restore_path] if restore_path else [] - model_dir = model_dir or restore_checkpoint_cfg.path - if use_orbax: - checkpoint_manager = create_orbax_checkpoint_manager( - save_cfg=save_checkpoint_cfg, - restore_cfg=restore_checkpoint_cfg, - train_state=train_state_initializer.global_train_state_shape, - partitioner=partitioner, - ds_iter=ds_iter, - model_dir=model_dir, - ) - train_state = restore( - checkpoint_manager, - restore_path_list, - restore_checkpoint_cfg, - get_fallback_state(restore_checkpoint_cfg, _init, fallback_init_rng), - ) - else: - checkpoint_manager = LegacyCheckpointManager( - save_cfg=save_checkpoint_cfg, - restore_cfg=restore_checkpoint_cfg, - train_state_shape=train_state_initializer.global_train_state_shape, - partitioner=partitioner, - ds_iter=ds_iter, - model_dir=model_dir, - ) - train_state = checkpoint_manager.restore( - restore_path_list, - restore_checkpoint_cfg, - get_fallback_state(restore_checkpoint_cfg, _init, fallback_init_rng), - ) - - if isinstance(train_state, Sequence): - raise ValueError('Cannot restore multiple checkpoints.') - return train_state, checkpoint_manager - - -def create_orbax_checkpoint_manager( - *, - save_cfg: Optional[SaveCheckpointConfig] = None, - restore_cfg: Optional[RestoreCheckpointConfig] = None, - train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - ds_iter: Optional[ - Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator] - ] = None, - model_dir: Optional[str] = None, -): - """Creates Orbax CheckpointManager.""" - if save_cfg is not None and restore_cfg is not None: - if ( - save_cfg.checkpoint_manager_cls - is not restore_cfg.checkpoint_manager_cls - ): - msg = ( - 'Must provide matching configurations of `checkpoint_manager_cls` in ' - '`save_cfg` and `restore_cfg`.' - ) - raise ValueError(msg) - - def _get_default_args(cls_or_fcn): - signature = inspect.signature(cls_or_fcn) - # Only get certain parameters needed for BestCheckpointManager - # configuration. These are the parameters of SaveBestCheckpointer that are - # not shared by regular Checkpointer. This whole approach is very hacky, but - # prevents us from needing to migrate every user to a new checkpoint config, - # which is the only alternative. - # Arguments aside from these should be set via CheckpointConfig, not gin. - - def _is_relevant_arg(key: str): - return key in { - 'metric_name_to_monitor', - 'metric_mode', - 'keep_checkpoints_without_metrics', - 'force_keep_period', - } - - return { - k: v.default - for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty and _is_relevant_arg(k) - # Without the filtering by name specified above, we would have duplicate - # parameters being passed, which would give an error. - } - - def _get_extra_kwargs(cfg): - extra_kwargs = {} - # Sometimes, the user pass in a `functools.partial` of - # `SaveBestCheckpointer` which will cause issubclass to raise an Exception - # since `functools.partial` is not a class. - if isinstance(cfg.checkpointer_cls, functools.partial): - # Note, this is intentionally moved out of the above if statement compared - # to the condition below. This is because we need to handle the kwargs - # differently since it's a functools.partial. - extra_kwargs = _get_default_args(cfg.checkpointer_cls) - else: - if issubclass( - type(cfg.checkpointer_cls), checkpoints.SaveBestCheckpointer - ): - save_best_checkpointer = checkpoints.SaveBestCheckpointer( - train_state=train_state, - checkpoints_dir=model_dir, - partitioner=partitioner, - ) - extra_kwargs = { - 'metric_name_to_monitor': ( - save_best_checkpointer._metric_name_to_monitor # pylint: disable=protected-access - ), - 'metric_mode': save_best_checkpointer._metric_mode, # pylint: disable=protected-access - 'keep_checkpoints_without_metrics': ( - save_best_checkpointer._keep_checkpoints_without_metrics # pylint: disable=protected-access - ), - 'force_keep_period': save_best_checkpointer._force_keep_period, # pylint: disable=protected-access - } - return extra_kwargs - - save_dtype = None - restore_dtype = None - period = None - keep = None - should_save_restore_dataset = False - checkpoint_manager_cls = None - extra_kwargs = {} - if save_cfg is not None: - should_save_restore_dataset |= save_cfg.save_dataset - save_dtype = save_cfg.dtype - keep = save_cfg.keep - period = save_cfg.period - checkpoint_manager_cls = save_cfg.checkpoint_manager_cls - extra_kwargs = _get_extra_kwargs(save_cfg) - if restore_cfg is not None: - should_save_restore_dataset |= restore_cfg.restore_dataset - restore_dtype = restore_cfg.dtype - # Doesn't matter if we reset this, since they're required to be the same - # anyway. - checkpoint_manager_cls = restore_cfg.checkpoint_manager_cls - # If already set, configuration from save_cfg takes precedence. Give - # extra_kwargs a chance to be reset to something more specialized, if not - # specified in save_cfg - if not extra_kwargs: - extra_kwargs = _get_extra_kwargs(restore_cfg) - ds_iter = ds_iter if should_save_restore_dataset else None - - return checkpoint_manager_cls( - directory=model_dir, - train_state=train_state, - partitioner=partitioner, - dataset_iterator=ds_iter, - save_dtype=save_dtype, - restore_dtype=restore_dtype, - keep=keep, - period=period, - **extra_kwargs, - ) - - -# ----------------------------------------------------------------------------- -# Logging utility functions -# ----------------------------------------------------------------------------- - - -def log_model_info( - log_file: Optional[str], - full_train_state: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, -): - """Log the variable shapes information and optionally write it to a file.""" - # Only write logs on host 0. - if jax.process_index() != 0: - return - - state_dict = full_train_state.state_dict() - total_num_params = jax.tree_util.tree_reduce( - np.add, jax.tree.map(np.size, state_dict['target']) - ) - - logical_axes = partitioner.get_logical_axes(full_train_state).state_dict() - - mesh_axes = jax.tree.map( - lambda x: tuple(x) if x is not None else None, - partitioner.get_mesh_axes(full_train_state).state_dict(), - ) - - def _log_info_and_write_to_file(writer, format_str, *args): - logging.info(format_str, *args) - if writer is not None: - writer.write(format_str % args + '\n') - - with contextlib.ExitStack() as stack: - writer = ( - stack.enter_context(gfile.GFile(log_file, 'w')) - if log_file is not None - else None - ) - - # Log params - def _log_variable( - name: str, - arr: Optional[np.ndarray], - logical_axes: Optional[partitioning.AxisNames], - mesh_axes: Optional[partitioning.PartitionSpec], - ): - # Log nothing on empty dict leaves, which occur with optax EmptyState(). - if isinstance(arr, dict) and not arr: - return - if arr is None: - _log_info_and_write_to_file(writer, 'Variable %-80s None', name) - return - if logical_axes is None or len(logical_axes) != len(arr.shape): - shape_str = str(arr.shape) - else: - shape_str = '({})'.format( - ', '.join( - f'{name}={dimension}' - for name, dimension in zip(logical_axes, arr.shape) - ) - ) - _log_info_and_write_to_file( - writer, - 'Variable %-80s size %-12s shape %-40s partition spec %s', - name, - arr.size, - shape_str, - mesh_axes, - ) - - jax.tree.map( - _log_variable, - state_utils.get_name_tree(state_dict['target'], keep_empty_nodes=True), - state_dict['target'], - logical_axes['target'], - mesh_axes['target'], - ) - - _log_info_and_write_to_file( - writer, 'Total number of parameters: %d', total_num_params - ) - - # Add a blank line between params and states. - _log_info_and_write_to_file(writer, '') - - jax.tree.map( - _log_variable, - state_utils.get_name_tree(state_dict['state'], keep_empty_nodes=True), - state_dict['state'], - logical_axes['state'], - mesh_axes['state'], - ) - - -# ----------------------------------------------------------------------------- -# Utility functions for prediction and evaluation. -# ----------------------------------------------------------------------------- - - -class InferStepWithRngCallable(typing_extensions.Protocol): - - def __call__( - self, - params: Mapping[str, Any], - batch: Mapping[str, jnp.ndarray], - rng: jnp.ndarray = None, # pytype: disable=annotation-type-mismatch # jax-ndarray - flax_mutables: Optional[Mapping[str, Any]] = None, - ) -> PyTree: - """Runs an inference step returning a prediction or score.""" - ... - - -class InferStepWithoutRngCallable(typing_extensions.Protocol): - - def __call__( - self, - params: Mapping[str, Any], - batch: Mapping[str, jnp.ndarray], - flax_mutables: Optional[Mapping[str, Any]] = None, - ) -> PyTree: - """Runs an inference step returning a prediction or score.""" - ... - - -InferStepCallable = Union[InferStepWithRngCallable, InferStepWithoutRngCallable] - -# NOTE: We're not more prescriptive than PyTree because that's what -# InferStepCallable expects. -_InferFnResult = Sequence[Tuple[int, PyTree]] -_InferFnWithAuxResult = Tuple[_InferFnResult, Mapping[str, Sequence[Any]]] - - -class InferFnCallable(typing_extensions.Protocol): - - def __call__( - self, - ds: tf.data.Dataset, - train_state: train_state_lib.TrainState, - rng: Optional[jnp.ndarray] = None, - ) -> Union[_InferFnResult, _InferFnWithAuxResult]: - """Runs inference on the dataset.""" - ... - - -def _remove_padding(all_inferences, all_indices): - """Remove padded examples. - - Args: - all_inferences: PyTree[total_examples + padding_count, ...]. - all_indices: [total_examples + padding_count]. - - Returns: - all_inferences in shape PyTree[total_examples, ...]. - all_indices in shape [total_exmamples]. - """ - non_pad_idxs = np.where(all_indices >= 0) - all_indices = all_indices[non_pad_idxs] - all_inferences = jax.tree.map(lambda x: x[non_pad_idxs], all_inferences) - return all_inferences, all_indices - - -def get_infer_fn( - infer_step: InferStepCallable, - batch_size: int, - train_state_axes: train_state_lib.TrainState, - partitioner: partitioning.BasePartitioner, - keep_aux_as_numpy: bool = False, -) -> InferFnCallable: - """Get prediction function for the SeqIO evaluator. - - The returned prediction function should take in an enumerated dataset, make - predictions and return in an enumerated form with the original indices and - examples zipped together. This ensures that the predictions are compared to - the targets in a correct order even if the dataset is sharded across - multiple hosts and gathered in a nondeterministic way. - - jax.process_index == 0 is used as a "main host", i.e., it gathers all - inference results and returns. - - Shape notation: - Per replica set num replicas: R - Per replica set batch size: B - Number of replica sets: H - Length: L - - Some transformations have shape transformation annotation, e.g., - [B, L] -> [R, B/R, L]. - - Args: - infer_step: a callable that executes one prediction step. Should not yet be - partitioned or pmapped. - batch_size: the number of examples in the global infer batch. - train_state_axes: Partitioning info for the train state object. - partitioner: partitioner to use. - keep_aux_as_numpy: bool. whether to leave aux values as numpy arrays; can be - used to save space when saving bfloat16s - - Returns: - predict_fn: a callable which takes in the enumerated infer dataset and an - optimizer and runs the prediction. - """ - - def infer_step_with_indices(params, batch, rng, indices, flax_mutables): - if 'rng' in inspect.signature(infer_step).parameters: - if 'flax_mutables' in inspect.signature(infer_step).parameters: - res = typing.cast(InferStepWithRngCallable, infer_step)( - params, batch, rng, flax_mutables=flax_mutables - ) - else: - res = typing.cast(InferStepWithRngCallable, infer_step)( - params, batch, rng - ) - else: - if 'flax_mutables' in inspect.signature(infer_step).parameters: - res = typing.cast(InferStepWithoutRngCallable, infer_step)( - params, batch, flax_mutables=flax_mutables - ) - else: - res = typing.cast(InferStepWithoutRngCallable, infer_step)( - params, batch - ) - return indices, res - - partitioned_infer_step = partitioner.partition( - infer_step_with_indices, - in_axis_resources=( - train_state_axes.params, - partitioner.data_partition_spec, - None, - partitioner.data_partition_spec, - train_state_axes.flax_mutables, - ), - out_axis_resources=(None, None), - ) - - data_layout = partitioner.get_data_layout(batch_size) - shard_id = data_layout.shard_id - num_shards = data_layout.num_shards - - per_shard_batch_size = batch_size // num_shards - - def infer_fn( - ds: tf.data.Dataset, - train_state: train_state_lib.TrainState, - rng: Optional[jnp.ndarray] = None, - ): - ds_shapes = jax.tree.map(lambda x: jnp.array(x.shape), ds.element_spec) - multihost_assert_equal( - ds_shapes, - ( - 'Dataset element shapes do not agree across hosts. ' - 'This could be an indication that the dataset is nondeterministic.' - ), - ) - try: - original_ds_length = len(ds) - dataset_remainder = original_ds_length % batch_size # pytype:disable=wrong-arg-types - logging.info('length of dataset = %s', len(ds)) - except TypeError as e: - if str(e).endswith('dataset length is unknown.'): - logging.warning( - 'The following error is likely due to the use of TensorFlow v1 in ' - 'your dataset pipeline. Verify you are not importing from ' - '`tf.compat.v1` as part of your pipeline.' - ) - raise e - - if dataset_remainder: - dataset_pad_amt = batch_size - dataset_remainder - logging.info( - 'Padding infer dataset with %d examples for even per-replica shards.', - dataset_pad_amt, - ) - # Pad with the first example using an index of -1 so seqio will ignore. - pad_ds = ( - ds.take(1) - .map(lambda i, x: (np.int64(-1), x)) - .cache() - .repeat(dataset_pad_amt) - ) - ds = ds.concatenate(pad_ds) - - # Shard the infer dataset across replica sets. - sharded_ds = ds.shard(num_shards, shard_id).batch( - per_shard_batch_size, drop_remainder=True - ) - multihost_assert_equal( - jnp.array(len(sharded_ds)), 'Dataset lengths do not agree across hosts.' - ) - - logging.info( - ( - 'The infer dataset is sharded into %d shards with per-shard ' - 'batch size of %d' - ), - num_shards, - per_shard_batch_size, - ) - - # Run inference for each replica set. - batched_results, all_indices = [], [] - for index, infer_batch in sharded_ds.as_numpy_iterator(): - if rng is None: - step_rng = None - else: - step_rng, rng = jax.random.split(rng) - # Run fast inference on batch. - # [B, ...] -> [B * shard_count, ...] - # partitioned_infer_step executes infer_step on sharded batched data, and - # returns de-sharded batched indices and result replicated on all hosts. - - if jax.process_count() > 1: - infer_batch_p, step_rng_p, index_p = ( - multihost_utils.host_local_array_to_global_array( - (infer_batch, step_rng, index), - partitioner.mesh, - ( - partitioner.data_partition_spec, - # Use empty partition spec instead of None - jax.sharding.PartitionSpec(), - partitioner.data_partition_spec, - ), - ) - ) - batch_indices, batch_result = partitioned_infer_step( - train_state.params, - infer_batch_p, - step_rng_p, - index_p, - train_state.flax_mutables, - ) - batch_indices, batch_result = ( - multihost_utils.global_array_to_host_local_array( - (batch_indices, batch_result), - partitioner.mesh, - # Use empty partition spec instead of None - (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec()), - ) - ) - else: - batch_indices, batch_result = partitioned_infer_step( - train_state.params, - infer_batch, - step_rng, - index, - train_state.flax_mutables, - ) - logging.info('Inference of batch %s done.', index) - - - def _copy_to_host_async(x): - if hasattr(x, 'addressable_data'): - # Array is fully replicated. - x.addressable_data(0).copy_to_host_async() - return x.addressable_data(0) - else: - x.copy_to_host_async() - return x - - try: - batch_result = jax.tree.map(_copy_to_host_async, batch_result) - batch_indices = jax.tree.map(_copy_to_host_async, batch_indices) - except AttributeError: - # Similar to jax.device_get, we skip transfers for non DeviceArrays. - pass - - batched_results.append(batch_result) - all_indices.append(batch_indices) - - logging.info('Inference of all batches done.') - all_inferences = batched_results - - # List[B * shard_count, ...] -> [B * shard_count * batch_count, ...] - all_inferences = jax.tree.map( - lambda *args: np.concatenate(args), *all_inferences - ) - all_indices = np.concatenate(all_indices) - - all_inferences, all_indices = _remove_padding(all_inferences, all_indices) - - # Results are returned from infer_step out of order due to shard operation. - # Note: remove padding first, as -1 indices would mess up this operation. - # Note: all_inferences may be a PyTree, not just an array, e.g. if - # `infer_step` is `model.predict_batch_with_aux`. - all_inferences = jax.tree.map(lambda x: x[all_indices], all_inferences) - all_indices = all_indices[all_indices] - - # aux_values is supposed to be a dictionary that maps strings to a set of - # auxiliary values. - # - # We don't want to flatten/unflatten the aux values. We want to preserve the - # unflattened values with the type List[Mapping[str, Sequence[Any]]]. We do - # this as a memory optimization to avoid lots of redundant keys if we'd - # instead had List[Mapping[str, Any]]. - # - # It has shape Mapping[str, [B * shard_count * batch_count, ...]]. That is, - # the first dimension of each of the values in aux_values is equal to - # len(all_inferences). - aux_values = None - if ( - isinstance(all_inferences, tuple) - and len(all_inferences) == 2 - and isinstance(all_inferences[1], Mapping) - ): - all_inferences, aux_values = all_inferences - - # Translate to List[...] by flattening inferences making sure to - # preserve structure of individual elements (inferences are not assumed to - # be simple np.array). Finally, zip inferences with corresponding indices - # and convert leaf np.arrays into lists. - all_inferences, struct = jax.tree_util.tree_flatten(all_inferences) - all_inferences = map( - functools.partial(jax.tree_util.tree_unflatten, struct), - zip(*all_inferences), - ) - indices_and_outputs = list(zip(all_indices, all_inferences)) - indices_and_outputs = jax.tree.map( - lambda x: np.array(x).tolist(), indices_and_outputs - ) - if len(indices_and_outputs) != original_ds_length: - raise ValueError( - 'Size of indices_and_outputs does not match length of original ' - 'dataset: %d versus %d' - % (len(indices_and_outputs), original_ds_length) - ) - - if aux_values is None: - return indices_and_outputs - else: - if keep_aux_as_numpy: - aux_values = jax.tree.map(lambda x: list(np.array(x)), aux_values) - else: - aux_values = jax.tree.map(lambda x: np.array(x).tolist(), aux_values) - return indices_and_outputs, aux_values - - return infer_fn - - -# ----------------------------------------------------------------------------- -# SeqIO utility functions. -# ----------------------------------------------------------------------------- - - -def import_module(module: str): - """Imports the given module at runtime.""" - logging.info('Importing %s.', module) - try: - importlib.import_module(module) - except RuntimeError as e: - if ( - str(e) - == 'Attempted to add a new configurable after the config was locked.' - ): - raise RuntimeError( - 'Your Task/Mixture module contains gin configurables that must be ' - 'loaded before gin flag parsing. One fix is to add ' - f"'import {module}' in your gin file." - ) from e - raise e - - -def get_vocabulary( - cfg: DatasetConfig, -) -> Tuple[seqio.Vocabulary, seqio.Vocabulary]: - """Returns `seqio.Vocabulary` objects associated with the `Mixture`/`Task`. - - Args: - cfg: the DatasetConfig specifying which mixture or task to get the - vocabularies for. - - Returns: - A tuple of seqio.Vocabulary for inputs and targets. - - Raises: - ValueError: if inputs and targets are not both present and vocabularies - are different. - """ - if cfg.module: - warnings.warn( - ( - 'The use of `DatasetConfig.module` and `MIXTURE_OR_TASK_MODULE` is ' - 'deprecated in favor of importing the module directly or via gin.' - ), - DeprecationWarning, - ) - import_module(cfg.module) - - if isinstance(cfg.mixture_or_task_name, airio.DatasetProviderBase): - mixture_or_task = cfg.mixture_or_task_name - vocab_map = airio.dataset_providers.get_vocabularies(mixture_or_task) - if not vocab_map: - raise ValueError( - f'No vocabularies found for AirIO task/mixture {mixture_or_task}' - ) - features = {k: seqio.Feature(vocabulary=v) for k, v in vocab_map.items()} # pytype: disable=wrong-arg-types - elif isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase): - mixture_or_task = cfg.mixture_or_task_name - features = mixture_or_task.output_features - else: - mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) - features = mixture_or_task.output_features - - if 'inputs' in features and 'targets' in features: - return (features['inputs'].vocabulary, features['targets'].vocabulary) - - # If a mix of PassThroughVocabularies and other Vocabularies are specified, - # use the non-PassThroughVocabularies. - # TODO(b/185912004): Remove this once a more general solution is implemented. - vocabularies = list( - f.vocabulary - for f in features.values() - if not isinstance(f.vocabulary, seqio.PassThroughVocabulary) - ) - - # Otherwise, if all of the vocabs are PassThroughVocabularies, use those. - if not vocabularies: - vocabularies = list(f.vocabulary for f in features.values()) - - # If there still aren't any vocabularies, raise an error. - if not vocabularies: - raise ValueError( - '"inputs" and "targets" are not both present, and ' - 'no vocabularies were set for any features.' - ) - - first_vocab = vocabularies[0] - for vocab in vocabularies[1:]: - if vocab != first_vocab: - raise ValueError( - '"inputs" and "targets" are not both present, and ' - 'vocabularies are different.' - ) - return (first_vocab, first_vocab) - - -def verify_matching_vocabs( - cfg: DatasetConfig, model: models.BaseTransformerModel -): - """Verify whether the task vocab matches the model vocab. - - The seqio Task and the Model both define their vocabularies - separately, but these vocabularies must match or else the training/inference - results will not be sensible. This functions validates that they do match, - under the assumption that this is a standard Encoder-only, Decoder-only, - or Encoder-decoder model. - - Args: - cfg: The DatasetConfig of the training/inference task. - model: A BaseTransformerModel model with input_vocabulary and - output_vocabulary attributes. - - Raises: - ValueError: If the task vocabulary does not match the model vocabulary. - """ - ds_vocabs = get_vocabulary(cfg) - if ( - ds_vocabs[0] != model.input_vocabulary - or ds_vocabs[1] != model.output_vocabulary - ): - raise ValueError( - 'Model and Task vocabularies do not match:\n' - f' task={cfg.mixture_or_task_name}\n' - f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n' - f' model.input_vocabulary={model.input_vocabulary}\n' - f' model.output_vocabulary={model.output_vocabulary}\n' - ) - - - - -def get_dataset( - cfg: DatasetConfig, - shard_id: int, - num_shards: int, - feature_converter_cls: Callable[..., seqio.FeatureConverter], - num_epochs: Optional[int] = None, - continue_from_last_checkpoint: bool = False, -) -> Union[tf.data.Dataset, clu.data.dataset_iterator.DatasetIterator]: - """Returns a dataset from SeqIO based on a `DatasetConfig`.""" - if continue_from_last_checkpoint: - raise ValueError( - '`continue_from_last_checkpoint` must be set to False as this is not ' - 'supported by this dataset fn.' - ) - del continue_from_last_checkpoint - - if cfg.module: - import_module(cfg.module) - - if cfg.batch_size % num_shards: - raise ValueError( - f'Batch size ({cfg.batch_size}) must be divisible by number of ' - f'shards ({num_shards}).' - ) - - seed = cfg.seed - - - shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards) - - if seed is None: - # Use a shared timestamp across devices as the seed. - seed = int(multihost_utils.broadcast_one_to_all(np.int32(time.time()))) - - return get_dataset_inner( - cfg, shard_info, feature_converter_cls, seed, num_epochs - ) - - -def get_dataset_inner( - cfg: DatasetConfig, - shard_info: seqio.ShardInfo, - feature_converter_cls: Callable[..., seqio.FeatureConverter], - seed: Optional[int] = None, - num_epochs: Optional[int] = None, -): - """Internal fn to load a dataset from SeqIO based on a `DatasetConfig`.""" - batch_size = cfg.batch_size // shard_info.num_shards - if isinstance( - cfg.mixture_or_task_name, - (seqio.DatasetProviderBase, airio.DatasetProviderBase), - ): - mixture_or_task = cfg.mixture_or_task_name - else: - mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) - if seed is not None: - if not str(jax.devices()[0]).startswith('MOCK_TPU'): - multihost_assert_equal( - np.array(seed), - ( - f'`seed` is not same across hosts; {jax.process_index} has a seed' - f' of {seed}' - ), - ) - logging.info( - ( - "Initializing dataset for task '%s' with a replica batch size of %d" - ' and a seed of %d' - ), - mixture_or_task.name, - batch_size, - seed, - ) - - in_memory_shuffle = cfg.shuffle - return seqio.get_dataset( - mixture_or_task_name=mixture_or_task, - task_feature_lengths=cfg.task_feature_lengths, - dataset_split=cfg.split, - shuffle=in_memory_shuffle, - num_epochs=num_epochs, - feature_converter=feature_converter_cls( - pack=cfg.pack, use_custom_packing_ops=cfg.use_custom_packing_ops - ), - shard_info=shard_info, - use_cached=cfg.use_cached, - seed=seed, - trim_output_features=cfg.trim_output_features, - batch_size=batch_size, - ) - - -class GetDatasetCallable(typing_extensions.Protocol): - """Interface for a function returning a dataset (iterator).""" - - def __call__( - self, - cfg: DatasetConfig, - shard_id: int, - num_shards: int, - feature_converter_cls: Callable[..., seqio.FeatureConverter], - num_epochs: Optional[int] = None, - continue_from_last_checkpoint: bool = True, - ) -> Union[ - clu.data.dataset_iterator.DatasetIterator, - tf.data.Dataset, - ]: - ... - - -class GetEvalDatasetCallable(typing_extensions.Protocol): - """Interface for a function returning a dataset (iterator).""" - - def __call__( - self, - cfg: DatasetConfig, - shard_id: int, - num_shards: int, - eval_steps: int, - feature_converter_cls: Callable[..., seqio.FeatureConverter], - ) -> Mapping[str, Union[tf.data.Dataset, airio.AirIODatasetIterator]]: - ... - - -def get_training_eval_datasets( - cfg: DatasetConfig, - shard_id: int, - num_shards: int, - eval_steps: int, - feature_converter_cls: Callable[..., seqio.FeatureConverter], - deterministic: bool = False, - model_dir: Optional[str] = None, - start_step: int = 0, -) -> Mapping[ - str, - Union[ - tf.data.Dataset, - clu.data.dataset_iterator.DatasetIterator, - Iterable[Mapping[str, np.ndarray]], - ], -]: - """Returns a mapping from eval task name to its dataset.""" - if isinstance( - cfg.mixture_or_task_name, - (seqio.DatasetProviderBase, airio.DatasetProviderBase), - ): - mixture_or_task = cfg.mixture_or_task_name - else: - mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) - datasets = {} - get_dataset_fn = get_dataset - if deterministic: - assert model_dir is not None - get_dataset_fn = functools.partial( - get_deterministic_dataset, model_dir=model_dir, start_step=start_step - ) - - if isinstance(mixture_or_task, (airio.Task, airio.Mixture)): - data_iter = get_dataset_fn( - dataclasses.replace(cfg, batch_size=1), - shard_id=0, - num_shards=1, - feature_converter_cls=feature_converter_cls, - num_epochs=eval_steps * cfg.batch_size, - continue_from_last_checkpoint=False, - ) - # TODO(b/304579895): Cannot use itertools.islice here to limit the number - # of records to eval_steps since peek() is required later. Instead, update - # t5x eval to not depend on the data pipeline and stop after eval_run - # steps. - datasets[mixture_or_task.name] = data_iter - return datasets - - if cfg.batch_size % num_shards: - raise ValueError( - f'Batch size ({cfg.batch_size}) must be divisible by number of ' - f'shards ({num_shards}).' - ) - - def _repeat_shard_batch_take_cache(ds: tf.data.Dataset): - # We shard and batch the full, repeated dataset to avoid issues with uneven - # file shards. - if not isinstance(ds, tf.data.Dataset): - raise ValueError('Only tf.data.Dataset objects supported.') - ds = ( - ds.unbatch() - .repeat() - .shard(num_shards, shard_id) - .batch(cfg.batch_size // num_shards, drop_remainder=True) - .take(eval_steps) - ) - if cfg.use_memory_cache: - return ds.cache() - else: - return ds - - for task in seqio.get_subtasks(mixture_or_task): - if cfg.split not in task.splits: - logging.info( - "Task %s has no '%s' split; skipping training evaluation.", - task.name, - cfg.split, - ) - continue - logging.info('Loading task %s for training evaluation.', task.name) - task_cfg = dataclasses.replace( - cfg, mixture_or_task_name=task.name, batch_size=1 - ) - # We set `num_epochs` to be finite to avoid infinite loops on shards that - # have input examples that are all filtered. - datasets[task.name] = _repeat_shard_batch_take_cache( - get_dataset_fn( - task_cfg, - shard_id=0, - num_shards=1, - feature_converter_cls=feature_converter_cls, - num_epochs=eval_steps * cfg.batch_size, - continue_from_last_checkpoint=False, - ) - ) - - if isinstance(mixture_or_task, seqio.Mixture): - datasets[mixture_or_task.name] = _repeat_shard_batch_take_cache( - get_dataset_fn( - dataclasses.replace(cfg, batch_size=1), - shard_id=0, - num_shards=1, - feature_converter_cls=feature_converter_cls, - num_epochs=eval_steps * cfg.batch_size, - continue_from_last_checkpoint=False, - ) - ) - - return datasets - - -def round_vocab_size_to_multiple( - vocabulary: seqio.Vocabulary, divisor: int = 128 -): - """Round up vocabulary size for improved TPU performance.""" - size = vocabulary.vocab_size - return size + -size % divisor - - -def flatten_dict_string_keys(x): - """Flattens a nested dictionary to have string keys and '/' separators.""" - return traverse_util.flatten_dict(flax.core.unfreeze(x), sep='/') - - -class _RegexMap(collections.abc.Mapping): - """Ordered mapping from regexes to values requiring a full match.""" - - def __init__(self, kvs: Sequence[Tuple[str, Any]]): - self._kvs = [(re.compile(k), v) for k, v in kvs] - - def __getitem__(self, key: str) -> Any: - for pattern, v in self._kvs: - if pattern.fullmatch(key): - return v - raise KeyError(f'No pattern matching key: {key}') - - def __len__(self) -> int: - return len(self._kvs) - - def __iter__(self) -> Iterable[Tuple[re.Pattern[str], Any]]: - return iter(self._kvs) - - -def override_params_axes_names( - model_variables: flax_scope.FrozenVariableDict, - params_axes_names_override: Sequence[Tuple[str, Tuple[str, ...]]] = (), -) -> flax_scope.FrozenVariableDict: - """Applies parameter axis names overrides to axes variables. - - Args: - model_variables: the original model variables containing the 'params_axes' - collection. - params_axes_names_override: a priority-ordered mapping from regex patterns - (fully matching parameter names) to tuples containing string logical axis - names to replace model-derived names. - - Returns: - an updated set of model variables with the overrides applied to the - 'params_axes' collection. - """ - params_axes_names_override_map = _RegexMap(params_axes_names_override) - - if 'params_axes' not in model_variables: - raise ValueError( - "Model variables do not contain a 'params_axes' collection to apply an " - 'override to.' - ) - model_variables = flax.core.unfreeze(model_variables) - flat_params = traverse_util.flatten_dict(model_variables['params']) - flat_params_axes = traverse_util.flatten_dict(model_variables['params_axes']) - - for key, param in flat_params.items(): - param_name = '/'.join(key) - override = params_axes_names_override_map.get(param_name) - if override is None: - continue - - param_axes_key = key[:-1] + (f'{key[-1]}_axes',) - - curr_metadata = flat_params_axes.get(param_axes_key) - - if curr_metadata is None: - logging.info('Adding axis names for %s: %s', param_name, override) - else: - assert isinstance(curr_metadata, flax_partitioning.AxisMetadata) - logging.info( - 'Replacing axis names for %s (%s) with %s.', - param_name, - curr_metadata.names, - override, - ) - - if param.ndim != len(override): - raise ValueError( - f'Provided axis name override for {param_name} does not match ' - f'param rank ({param.ndim}): {override}' - ) - flat_params_axes[param_axes_key] = flax_partitioning.AxisMetadata( - names=override - ) - - model_variables['params_axes'] = traverse_util.unflatten_dict( - flat_params_axes - ) - return flax.core.freeze(model_variables) - - -def find_next_checkpoint_step( - checkpoint_steps_index: int, - inner_num_steps: int, - is_checkpoint_step: bool, - host_step: int, - checkpoint_steps: Sequence[int], - epoch_end_step: int, - checkpoint_period: int, - first_step: int, -): - """Finds next valid checkpoint step in checkpoint_steps list parameter. - - This will stop scalar training and save a checkpoint at that step. - Checkpoint step is considered valid if less than or equal to epoch end step, - not at a concurrent checkpoint_period step, and greater than the first - step. Specifically, this will decrease inner_num_steps (the current number - of scalar steps to train) if host_step + inner_num_steps would otherwise - pass a desired checkpoint step. - - Args: - checkpoint_steps_index: Current index in checkpoint_steps list while - training. - inner_num_steps: Number of scalar steps to iterate through in training. - is_checkpoint_step: Dictates whether the current subset of scalar steps - contains a valid checkpoint_step to save. - host_step: Host step of training. - checkpoint_steps: List of checkpoint_stems passed in as parameter in - checkpoint_cfg.save. - epoch_end_step: Last training step in epoch. - checkpoint_period: Period value passed in as parameter in - checkpoint_cfg.save. - first_step: First step while training. Or, this may be the first step after - resuming a job (e.g. after pre-emption). - - Returns: - Tuple containing (possibly) halted inner_num_steps value and - is_checkpoint_step if checkpoint step value was found. - """ - # checks to see if inner_num_steps + host_step will out-pace current - # checkpoint steps index that should be saved - while ( - checkpoint_steps - and (inner_num_steps + host_step) - >= checkpoint_steps[checkpoint_steps_index] - ): - curr_checkpoint_step = checkpoint_steps[checkpoint_steps_index] - if ( - ((curr_checkpoint_step - first_step) % checkpoint_period != 0) - and curr_checkpoint_step > first_step - and curr_checkpoint_step <= epoch_end_step - ): - # Found a checkpoint step before inner_num_steps + host_step that - # is not part of the normal period. - inner_num_steps = curr_checkpoint_step - host_step - is_checkpoint_step = True - return (inner_num_steps, is_checkpoint_step) - if ( - not is_checkpoint_step - and checkpoint_steps_index == len(checkpoint_steps) - 1 - ): - # Iterated through all of checkpoint_steps, and there are no new - # checkpoint steps before inner_num_steps + host_step. - return (inner_num_steps, is_checkpoint_step) - checkpoint_steps_index += 1 - # Reached inner_num_steps + host_step and no new checkpoint steps found. - # Return the original settings. - return (inner_num_steps, is_checkpoint_step) - - -def find_first_checkpoint_step( - checkpoint_steps_index: int, - checkpoint_steps: Sequence[int], - first_step: int, - host_step: int, -) -> int: - """Finds the first valid step in checkpoint_step list parameter to save a checkpoint at. - - Args: - checkpoint_steps_index: Current index in checkpoint_steps list while - training. - checkpoint_steps: List of checkpoint_stems passed in as parameter in - checkpoint_cfg.save. - first_step: First step in epoch while training. - host_step: Host step of training. - - Returns: - Integer containing first valid checkpoint step index after host_step and not - equal to first_step. - """ - while ( - checkpoint_steps - and checkpoint_steps_index < len(checkpoint_steps) - 1 - and ( - host_step > checkpoint_steps[checkpoint_steps_index] - or checkpoint_steps[checkpoint_steps_index] == first_step - ) - ): - checkpoint_steps_index += 1 - return checkpoint_steps_index - - - - -def run_main(main, flags_parser): - app.run(main, flags_parser=flags_parser) diff --git a/t5x-main/t5x/utils_test.py b/t5x-main/t5x/utils_test.py deleted file mode 100644 index edcb464ae6e96570a52cdc41b568beb10ffa0ed3..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/utils_test.py +++ /dev/null @@ -1,1114 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for t5x.utils.""" - -import dataclasses -import functools -import os -import re -from typing import Mapping, Optional -import unittest - -from absl import flags -from absl.testing import absltest -from absl.testing import parameterized -import airio.core as airio -import airio.pygrain_common as airio_common -import flax.core -from flax.linen import partitioning as flax_partitioning -import jax -from jax.experimental.pjit import pjit -import numpy as np -import seqio -from t5x import checkpoints -from t5x import partitioning -from t5x import test_utils -from t5x import train_state as train_state_lib -from t5x import utils -import tensorflow as tf - -mock = absltest.mock -Evaluator = seqio.Evaluator -PartitionSpec = partitioning.PartitionSpec -AxisMetadata = flax_partitioning.AxisMetadata -FlaxMutables = flax.core.FrozenDict - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - -FLAGS = flags.FLAGS - - -def get_mock_train_state(params, param_states=None, step=0, flax_mutables=None): - """Returns a mock TrainState.""" - step = np.array(step) if step is not None else None - state = mock.Mock(param_states=param_states, step=step) - state_dict = dict( - target=params, state=dict(param_states=param_states, step=step) - ) - args = dict( - params=params, - param_states=param_states, - step=step, - state_dict=lambda: state_dict, - optimizer=mock.Mock( - target=params, state=state, state_dict=lambda: state_dict - ), - flax_mutables=flax_mutables, - ) - return mock.Mock(**args) - - -class UtilsTest(parameterized.TestCase): - - def round_vocab_size_to_multiple(self): - self.assertEqual(utils.round_vocab_size_to_multiple(1), 128) - self.assertEqual(utils.round_vocab_size_to_multiple(128), 128) - self.assertEqual(utils.round_vocab_size_to_multiple(129), 256) - self.assertEqual(utils.round_vocab_size_to_multiple(129), 256) - self.assertEqual( - utils.round_vocab_size_to_multiple(25600, divisor=384), 256128 - ) - - def test_get_zeros_batch_like_spec(self): - test_utils.assert_same( - utils.get_zeros_batch_like_spec({ - "i": jax.ShapeDtypeStruct((2, 5), dtype=np.int32), - "j": jax.ShapeDtypeStruct((1,), dtype=np.float32), - }), - { - "i": np.zeros((2, 5), dtype=np.int32), - "j": np.zeros((1,), dtype=np.float32), - }, - ) - - def test_get_zeros_batch_like_dataset(self): - ds = tf.data.Dataset.from_tensors({ - "i": np.arange(10, dtype=np.int32).reshape((2, 5)), - "j": np.ones((1,), dtype=np.float32), - }) - - test_utils.assert_same( - utils.get_zeros_batch_like_dataset(ds), - { - "i": np.zeros((2, 5), dtype=np.int32), - "j": np.zeros((1,), dtype=np.float32), - }, - ) - - test_utils.assert_same( - utils.get_zeros_batch_like_dataset(ds, batch_size=4), - { - "i": np.zeros((4, 5), dtype=np.int32), - "j": np.zeros((4,), dtype=np.float32), - }, - ) - - @parameterized.named_parameters( - dict(testcase_name="write_to_file", write_to_log_file=True), - dict(testcase_name="do_not_write_to_file", write_to_log_file=False), - ) - def test_log_model_info(self, write_to_log_file): - log_file = self.create_tempfile() if write_to_log_file else None - - mock_train_state = get_mock_train_state( - params={ - "a": {"aa": jax.ShapeDtypeStruct(shape=(2, 3), dtype=np.int32)}, - "c": jax.ShapeDtypeStruct(shape=(7, 8), dtype=np.int32), - }, - param_states={ - "a": { - "aa": { - "v_row": jax.ShapeDtypeStruct(shape=(2,), dtype=np.int32), - "v_col": jax.ShapeDtypeStruct(shape=(3,), dtype=np.int32), - } - }, - "c": { - "v_row": jax.ShapeDtypeStruct(shape=(2, 4), dtype=np.int32), - "v_col": None, - }, - }, - ) - - mock_logical_axes = get_mock_train_state( - params={ - "a": {"aa": partitioning.AxisNames("a1", None)}, - "c": partitioning.AxisNames(None, "a1"), - }, - param_states={ - "a": { - "aa": { - "v_row": partitioning.AxisNames( - None, - ), - "v_col": partitioning.AxisNames( - None, - ), - } - }, - "c": { - "v_row": partitioning.AxisNames( - "a1", - ), - "v_col": partitioning.AxisNames( - "a2", - ), - }, - }, - step=None, - ) - - mock_mesh_axes = get_mock_train_state( - params={ - "a": {"aa": PartitionSpec("b1", None)}, - "c": PartitionSpec(None, "b1"), - }, - param_states={ - "a": { - "aa": { - "v_row": partitioning.AxisNames( - None, - ), - "v_col": partitioning.AxisNames( - None, - ), - } - }, - "c": { - "v_row": partitioning.AxisNames( - "b1", - ), - "v_col": partitioning.AxisNames( - "b2", - ), - }, - }, - step=None, - ) - - partitioner = mock.Mock( - get_logical_axes=lambda _: mock_logical_axes, - get_mesh_axes=lambda _: mock_mesh_axes, - ) - - with self.assertLogs(level="INFO") as logs: - utils.log_model_info( - log_file and log_file.full_path, mock_train_state, partitioner - ) - - relevant_logs = [ - re.sub(r"\s+", " ", output) - for record, output in zip(logs.records, logs.output) - if "t5x/utils.py" in record.pathname - ] - self.assertLen(relevant_logs, 9) - self.assertIn( - "Variable a/aa size 6 shape (a1=2, None=3) partition spec ('b1', None)", - relevant_logs[0], - ) - self.assertIn( - "Variable c size 56 shape (None=7, a1=8) partition spec (None, 'b1')", - relevant_logs[1], - ) - - if write_to_log_file: - self.assertEqual( - re.sub(r"\s+", " ", log_file.read_text()), - ( - "Variable a/aa size 6 shape (a1=2, None=3) partition spec ('b1'," - " None) Variable c size 56 shape (None=7, a1=8) partition spec" - " (None, 'b1') Total number of parameters: 62 Variable" - " param_states/a/aa/v_col size 3 shape (None=3) partition spec" - " (None,) Variable param_states/a/aa/v_row size 2 shape (None=2)" - " partition spec (None,) Variable param_states/c/v_col None" - " Variable param_states/c/v_row size 8 shape (2, 4) partition" - " spec ('b1',) Variable step size 1 shape () partition spec None " - ), - ) - - @mock.patch.object(utils, "get_dataset") - def test_get_training_eval_datasets_task(self, mock_get_dataset): - task = mock.create_autospec(seqio.Task, instance=True) - task.name = "mock_task" - task.splits = set(["train", "test"]) - seqio.TaskRegistry.add_provider("mock_task", task) - - mock_get_dataset.return_value = tf.data.Dataset.range(10).batch(1) - mock_fc_cls = mock.Mock() - - cfg = utils.DatasetConfig( - mixture_or_task_name="mock_task", - task_feature_lengths={}, - split="test", - batch_size=4, - shuffle=False, - seed=None, - ) - - # Single shard. - ds = utils.get_training_eval_datasets( - cfg, - shard_id=0, - num_shards=1, - eval_steps=3, - feature_converter_cls=mock_fc_cls, - ) - - mock_get_dataset.assert_called_once_with( - dataclasses.replace(cfg, batch_size=1), - shard_id=0, - num_shards=1, - feature_converter_cls=mock_fc_cls, - num_epochs=12, - continue_from_last_checkpoint=False, - ) - - self.assertSameElements(ds.keys(), ["mock_task"]) - jax.tree.map( - np.testing.assert_equal, - list(ds["mock_task"]), - [ - np.array([0, 1, 2, 3]), - np.array([4, 5, 6, 7]), - np.array([8, 9, 0, 1]), - ], - ) - - # 2 shards, shard 0 - mock_get_dataset.reset_mock() - ds = utils.get_training_eval_datasets( - cfg, - shard_id=0, - num_shards=2, - eval_steps=3, - feature_converter_cls=mock_fc_cls, - ) - - # Call the underlying function loading all shards since the fn shards at the - # example level. - mock_get_dataset.assert_called_once_with( - dataclasses.replace(cfg, batch_size=1), - shard_id=0, - num_shards=1, - feature_converter_cls=mock_fc_cls, - num_epochs=12, - continue_from_last_checkpoint=False, - ) - - self.assertSameElements(ds.keys(), ["mock_task"]) - jax.tree.map( - np.testing.assert_equal, - list(ds["mock_task"]), - [ - np.array([0, 2]), - np.array([4, 6]), - np.array([8, 0]), - ], - ) - - # 2 shards, shard 1 - mock_get_dataset.reset_mock() - ds = utils.get_training_eval_datasets( - cfg, - shard_id=1, - num_shards=2, - eval_steps=3, - feature_converter_cls=mock_fc_cls, - ) - - # Call the underlying function loading all shards since the fn shards at the - # example level. - mock_get_dataset.assert_called_once_with( - dataclasses.replace(cfg, batch_size=1), - shard_id=0, - num_shards=1, - feature_converter_cls=mock_fc_cls, - num_epochs=12, - continue_from_last_checkpoint=False, - ) - - self.assertSameElements(ds.keys(), ["mock_task"]) - jax.tree.map( - np.testing.assert_equal, - list(ds["mock_task"]), - [ - np.array([1, 3]), - np.array([5, 7]), - np.array([9, 1]), - ], - ) - - # 3 shards - with self.assertRaisesWithLiteralMatch( - ValueError, "Batch size (4) must be divisible by number of shards (3)." - ): - _ = utils.get_training_eval_datasets( - cfg, - shard_id=0, - num_shards=3, - eval_steps=3, - feature_converter_cls=mock_fc_cls, - ) - - @mock.patch.object(utils, "get_dataset") - def test_get_training_eval_datasets_mixture(self, mock_get_dataset): - # Register a mock SeqIO mixture. - task1 = mock.create_autospec(seqio.Task, instance=True) - task1.name = "mock_task1" - task1.splits = set(["train", "test"]) - task2 = mock.create_autospec(seqio.Task, instance=True) - task2.name = "mock_task2" - task2.splits = set(["train", "test"]) - seqio.TaskRegistry.add_provider("mock_task1", task1) - seqio.TaskRegistry.add_provider("mock_task2", task2) - mixture = seqio.Mixture( - "mock_mix", ["mock_task1", "mock_task2"], default_rate=1.0 - ) - seqio.MixtureRegistry.add_provider("mock_mix", mixture) - - mock_get_dataset.return_value = tf.data.Dataset.range(10).batch(1) - - # Verify calls to utils.get_dataset - cfg = utils.DatasetConfig( - mixture_or_task_name="mock_mix", - task_feature_lengths={}, - split="test", - batch_size=4, - shuffle=False, - seed=23, - ) - - res = utils.get_training_eval_datasets( - cfg, - shard_id=0, - num_shards=2, - eval_steps=3, - feature_converter_cls=seqio.FeatureConverter, - ) - - expected_calls = [ - mock.call( - dataclasses.replace( - cfg, mixture_or_task_name="mock_task1", batch_size=1 - ), - shard_id=0, - num_shards=1, - feature_converter_cls=seqio.FeatureConverter, - continue_from_last_checkpoint=False, - num_epochs=12, - ), - mock.call( - dataclasses.replace( - cfg, mixture_or_task_name="mock_task2", batch_size=1 - ), - shard_id=0, - num_shards=1, - feature_converter_cls=seqio.FeatureConverter, - continue_from_last_checkpoint=False, - num_epochs=12, - ), - mock.call( - dataclasses.replace( - cfg, mixture_or_task_name="mock_mix", batch_size=1 - ), - shard_id=0, - num_shards=1, - feature_converter_cls=seqio.FeatureConverter, - continue_from_last_checkpoint=False, - num_epochs=12, - ), - ] - mock_get_dataset.assert_has_calls(expected_calls) - - self.assertSameElements( - res.keys(), ["mock_task1", "mock_task2", "mock_mix"] - ) - for ds in res.values(): - jax.tree.map( - np.testing.assert_equal, - list(ds), - [ - np.array([0, 2]), - np.array([4, 6]), - np.array([8, 0]), - ], - ) - - @mock.patch.object(utils, "get_dataset") - def test_get_training_eval_datasets_mixture_obj(self, mock_get_dataset): - # Verify calls to utils.dataset using seqio.Task or seqio.Mixture - # Register a mock SeqIO mixture. - task3 = mock.create_autospec(seqio.Task, instance=True) - task3.name = "mock_task3" - task3.splits = set(["train", "test"]) - task4 = mock.create_autospec(seqio.Task, instance=True) - task4.name = "mock_task4" - task4.splits = set(["train", "test"]) - seqio.TaskRegistry.add_provider("mock_task3", task3) - seqio.TaskRegistry.add_provider("mock_task4", task4) - mixture = seqio.Mixture( - "mock_mix2", ["mock_task3", "mock_task4"], default_rate=1.0 - ) - seqio.MixtureRegistry.add_provider("mock_mix2", mixture) - - mock_get_dataset.return_value = tf.data.Dataset.range(10).batch(1) - cfg_obj = utils.DatasetConfig( - mixture_or_task_name=mixture, - task_feature_lengths={}, - split="test", - batch_size=4, - shuffle=False, - seed=23, - ) - - res_obj = utils.get_training_eval_datasets( - cfg_obj, - shard_id=0, - num_shards=2, - eval_steps=3, - feature_converter_cls=seqio.FeatureConverter, - ) - - expected_calls = expected_calls = [ - mock.call( - dataclasses.replace( - cfg_obj, mixture_or_task_name="mock_task3", batch_size=1 - ), - shard_id=0, - num_shards=1, - feature_converter_cls=seqio.FeatureConverter, - continue_from_last_checkpoint=False, - num_epochs=12, - ), - mock.call( - dataclasses.replace( - cfg_obj, mixture_or_task_name="mock_task4", batch_size=1 - ), - shard_id=0, - num_shards=1, - feature_converter_cls=seqio.FeatureConverter, - continue_from_last_checkpoint=False, - num_epochs=12, - ), - mock.call( - dataclasses.replace(cfg_obj, batch_size=1), - shard_id=0, - num_shards=1, - feature_converter_cls=seqio.FeatureConverter, - continue_from_last_checkpoint=False, - num_epochs=12, - ), - ] - - mock_get_dataset.assert_has_calls(expected_calls) - - self.assertSameElements( - res_obj.keys(), ["mock_task3", "mock_task4", "mock_mix2"] - ) - for ds in res_obj.values(): - jax.tree.map( - np.testing.assert_equal, - list(ds), - [ - np.array([0, 2]), - np.array([4, 6]), - np.array([8, 0]), - ], - ) - - def test_override_params_axes_names(self): - model_variables = flax.core.freeze({ - "params": { - "logits_dense": np.zeros((2, 4)), - "mlp": { - "wo": { - "kernel": np.zeros((4, 6)), - "bias": np.zeros(6), - } - }, - }, - "params_axes": { - "logits_dense_axes": AxisMetadata(names=("vocab", "embed")), - "mlp": { - "wo": {"kernel_axes": AxisMetadata(names=("embed", "mlp"))} - }, - }, - }) - - with self.assertRaisesWithLiteralMatch( - ValueError, - ( - "Model variables do not contain a 'params_axes' collection to apply" - " an override to." - ), - ): - utils.override_params_axes_names( - {"params": model_variables["params"]}, [("mlp/wo/kernel", ("embed",))] - ) - - with self.assertRaisesWithLiteralMatch( - ValueError, - ( - "Provided axis name override for mlp/wo/kernel does not match param" - " rank (2): ('embed',)" - ), - ): - utils.override_params_axes_names( - model_variables, [("mlp/wo/kernel", ("embed",))] - ) - - overridden_variables = utils.override_params_axes_names( - model_variables, - [ - ("wo/kernel", ("batch",)), # unused since not a full match - (".*/wo/kernel", ("batch", "embed")), # this one is used - ("mlp/wo/kernel", ("embed",)), # unused since already matched - ("mlp/wo/bias", ("embed",)), # used - ], - ) - - jax.tree.map( - np.testing.assert_equal, - overridden_variables, - flax.core.freeze({ - "params": { - "logits_dense": np.zeros((2, 4)), - "mlp": { - "wo": { - "kernel": np.zeros((4, 6)), - "bias": np.zeros(6), - } - }, - }, - "params_axes": { - "logits_dense_axes": AxisMetadata(names=("vocab", "embed")), - "mlp": { - "wo": { - "kernel_axes": AxisMetadata(names=("batch", "embed")), - "bias_axes": AxisMetadata(names=("embed",)), - } - }, - }, - }), - ) - - def test_create_orbax_checkpoint_manager_validate(self): - directory = "path/to/dir" - path = os.path.join(directory, "checkpoint") - save_cfg = utils.SaveCheckpointConfig( - checkpoint_manager_cls=checkpoints.OrbaxCheckpointManagerInterface - ) - restore_cfg = utils.RestoreCheckpointConfig( - path=path, checkpoint_manager_cls=None - ) - with self.assertRaises(ValueError): - utils.create_orbax_checkpoint_manager( - save_cfg=save_cfg, - restore_cfg=restore_cfg, - train_state=mock.Mock(), - partitioner=mock.Mock(), - model_dir=directory, - ) - - @parameterized.named_parameters( - dict( - testcase_name="using_best_checkpoint_manager", - checkpointer_cls=checkpoints.Checkpointer, - checkpoint_manager_cls=checkpoints.OrbaxCheckpointManagerInterface, - metrics=None, - expected_metric_mode="max", - expected_force_keep_period=None, - expected_keep_checkpoints_without_metrics=True, - expected_metric=None, - ), - dict( - testcase_name="using_functools_partial_for_checkpoint_manager_cls", - checkpointer_cls=checkpoints.Checkpointer, - checkpoint_manager_cls=functools.partial( - checkpoints.OrbaxCheckpointManagerInterface, - metric_name_to_monitor="loss", - metric_mode="min", - force_keep_period=10, - keep_checkpoints_without_metrics=False, - ), - metrics={"accuracy": 0.8, "loss": 0.1}, - expected_metric_mode="min", - expected_force_keep_period=10, - expected_keep_checkpoints_without_metrics=False, - expected_metric=0.1, - ), - dict( - testcase_name="using_save_best_checkpointer", - checkpointer_cls=checkpoints.SaveBestCheckpointer, - checkpoint_manager_cls=checkpoints.OrbaxCheckpointManagerInterface, - metrics={"train/accuracy": 0.8, "train/loss": 0.1}, - expected_metric_mode="max", - expected_force_keep_period=None, - expected_keep_checkpoints_without_metrics=True, - expected_metric=0.8, - ), - dict( - testcase_name="using_functools_partial_for_save_best_checkpointer", - checkpointer_cls=functools.partial( - checkpoints.SaveBestCheckpointer, - metric_name_to_monitor="train/loss", - metric_mode="min", - force_keep_period=20, - keep_checkpoints_without_metrics=False, - ), - checkpoint_manager_cls=checkpoints.OrbaxCheckpointManagerInterface, - metrics={"train/accuracy": 0.8, "train/loss": 0.1}, - expected_metric_mode="min", - expected_force_keep_period=20, - expected_keep_checkpoints_without_metrics=False, - expected_metric=0.1, - ), - ) - def test_create_orbax_checkpoint_manager( - self, - checkpointer_cls, - checkpoint_manager_cls, - metrics: Mapping[str, float], - expected_metric_mode: str, - expected_force_keep_period: int, - expected_keep_checkpoints_without_metrics: bool, - expected_metric: float, - ): - directory = self.create_tempdir(name="all_checkpoints") - path = os.path.join(directory, "checkpoint") - mock_data_layout = mock.Mock( - shard_id=0, num_shards=1, is_first_host_in_replica_set=True - ) - mock_partitioner = mock.Mock(get_data_layout=lambda: mock_data_layout) - save_cfg = utils.SaveCheckpointConfig( - checkpointer_cls=checkpointer_cls, - checkpoint_manager_cls=checkpoint_manager_cls, - dtype="float32", - keep=5, - period=2, - save_dataset=True, - ) - restore_cfg = utils.RestoreCheckpointConfig( - path=path, - checkpointer_cls=checkpointer_cls, - checkpoint_manager_cls=checkpoint_manager_cls, - dtype="bfloat16", - restore_dataset=False, - ) - - global_mesh = test_utils.create_global_mesh((4, 2), ("x", "y")) - mesh_axes = partitioning.PartitionSpec("x", "y") - global_input_shape = (8, 2) - - train_state = test_utils.make_train_state( - global_mesh, global_input_shape, mesh_axes - ) - - manager = utils.create_orbax_checkpoint_manager( - save_cfg=save_cfg, - restore_cfg=restore_cfg, - train_state=train_state, - partitioner=mock_partitioner, - ds_iter=mock.Mock(), - model_dir=directory, - ) - - self.assertIsInstance(manager, checkpoints.OrbaxCheckpointManagerInterface) - self.assertEqual(manager._manager._options.max_to_keep, 5) - self.assertEqual(manager._manager._options.save_interval_steps, 2) - self.assertEqual(manager._save_dtype, "float32") - self.assertEqual(manager._restore_dtype, "bfloat16") - self.assertTrue(manager._should_write_dataset_ckpt) - - # Save best options. - self.assertEqual( - manager._manager._options.keep_period, expected_force_keep_period - ) - self.assertEqual(manager._manager._options.best_mode, expected_metric_mode) - self.assertEqual( - manager._manager._options.keep_checkpoints_without_metrics, - expected_keep_checkpoints_without_metrics, - ) - if expected_metric is None: - self.assertIsNone(manager._manager._options.best_fn) - else: - self.assertEqual( - manager._manager._options.best_fn(metrics), - expected_metric, - ) - - @parameterized.parameters((True,), (False,)) - def test_create_orbax_checkpoint_manager_from_checkpointer( - self, set_save_checkpointer_cls - ): - directory = self.create_tempdir(name="all_checkpoints") - path = os.path.join(directory, "checkpoint") - mock_data_layout = mock.Mock( - shard_id=0, num_shards=1, is_first_host_in_replica_set=True - ) - mock_partitioner = mock.Mock(get_data_layout=lambda: mock_data_layout) - save_checkpointer_cls = ( - checkpoints.SaveBestCheckpointer - if set_save_checkpointer_cls - else checkpoints.Checkpointer - ) - restore_checkpointer_cls = ( - checkpoints.Checkpointer - if set_save_checkpointer_cls - else checkpoints.SaveBestCheckpointer - ) - save_cfg = utils.SaveCheckpointConfig( - checkpointer_cls=save_checkpointer_cls, - dtype="float32", - keep=5, - period=2, - save_dataset=True, - ) - restore_cfg = utils.RestoreCheckpointConfig( - checkpointer_cls=restore_checkpointer_cls, - path=path, - dtype="bfloat16", - restore_dataset=False, - ) - - global_mesh = test_utils.create_global_mesh((4, 2), ("x", "y")) - mesh_axes = partitioning.PartitionSpec("x", "y") - global_input_shape = (8, 2) - - train_state = test_utils.make_train_state( - global_mesh, global_input_shape, mesh_axes - ) - - manager = utils.create_orbax_checkpoint_manager( - save_cfg=save_cfg, - restore_cfg=restore_cfg, - train_state=train_state, - partitioner=mock_partitioner, - ds_iter=mock.Mock(), - model_dir=directory, - ) - - self.assertIsInstance(manager, checkpoints.OrbaxCheckpointManagerInterface) - self.assertEqual(manager._manager._options.max_to_keep, 5) - self.assertEqual(manager._manager._options.save_interval_steps, 2) - self.assertEqual(manager._save_dtype, "float32") - self.assertEqual(manager._restore_dtype, "bfloat16") - self.assertTrue(manager._should_write_dataset_ckpt) - - # Save best options. - self.assertIsNone(manager._manager._options.keep_period, None) - self.assertEqual(manager._manager._options.best_mode, "max") - self.assertTrue(manager._manager._options.keep_checkpoints_without_metrics) - self.assertEqual( - manager._manager._options.best_fn( - {"train/accuracy": 0.8, "train/loss": 0.1} - ), - 0.8, - ) - - -@dataclasses.dataclass -class MockTrainState: - path: Optional[str] = None - from_scratch: Optional[bool] = None - - -class MockCheckpointer(checkpoints.Checkpointer): - - def __init__(self, *args, **kwargs): - pass - - # restore should return TrainState, but we force it to return Mock with path - # for simplicity. - def restore(self, path, *args, **kwargs): - return MockTrainState(path=path, from_scratch=False) - - -class MockCheckpointManager(checkpoints.OrbaxCheckpointManagerInterface): - - def __init__(self, *args, **kwargs): - pass - - def restore(self, path, *args, **kwargs): - return MockTrainState(path=path, from_scratch=False) - - -class OrbaxRestoreTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - - self.ckptdir = self.create_tempdir(name="primary_checkpoints") - steps = (2, 3) - self.paths = [] - for s in steps: - step_dir = self.ckptdir.mkdir(f"checkpoint_{s}") - step_dir.create_file("checkpoint") - self.paths += [step_dir.full_path] - - def test_orbax_restore(self): - # Properties of config not needed in this test. - restore_cfg = utils.RestoreCheckpointConfig(path=[]) - - manager = MockCheckpointManager() - restored = utils.restore(manager, [self.paths[0]], restore_cfg) - self.assertIsInstance(restored, MockTrainState) - self.assertEqual(restored.path, self.paths[0]) - - restored = utils.restore(manager, self.paths, restore_cfg) - self.assertIsInstance(restored, list) - self.assertSequenceEqual([state.path for state in restored], self.paths) - - -class TrainStateInitializerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - - def _partition(train_state, in_axis_resources, out_axis_resources): - del train_state, in_axis_resources, out_axis_resources - partitioned_fn = lambda _: MockTrainState(from_scratch=True) - return partitioned_fn - - partitioner = mock.Mock(get_mesh_axes=lambda _: None, partition=_partition) - mock_inference_state_create = self.enter_context( - mock.patch.object(train_state_lib.InferenceState, "create") - ) - mock_inference_state_create.return_value = None - - shapes = { - "ones": (1, 1), - "twos": (2, 2), - "threes": (3, 3), - } - types = { - "ones": int, - "twos": float, - "threes": int, - } - - def _init_fn(rng, input_shapes, input_types): - del rng - return { - "ones": np.ones(input_shapes["ones"], dtype=input_types["ones"]), - "twos": np.ones(input_shapes["twos"], dtype=input_types["twos"]) * 2, - "threes": ( - np.ones(input_shapes["threes"], dtype=input_types["threes"]) * 3 - ), - } - - init_fn = mock.Mock(side_effect=_init_fn) - init_fn.__self__ = None - - self.train_state_init = utils.TrainStateInitializer( - None, init_fn, shapes, partitioner, types - ) - - self.ckptdir = self.create_tempdir(name="primary_checkpoints") - steps = (2, 3) - self.paths = [] - for s in steps: - step_dir = self.ckptdir.mkdir(f"checkpoint_{s}") - step_dir.create_file("checkpoint") - self.paths += [step_dir.full_path] - - def test_from_checkpoints_specific(self): - # multiple paths - ckpt_cfg = utils.RestoreCheckpointConfig( - path=self.paths, mode="specific", checkpointer_cls=MockCheckpointer - ) - restored = list(self.train_state_init.from_checkpoints([ckpt_cfg])) - self.assertSequenceEqual(self.paths, [state.path for state, _ in restored]) - self.assertSequenceEqual(self.paths, [path for _, path in restored]) - with self.assertRaisesRegex(ValueError, r"^Expected at most 1 checkpoint"): - self.train_state_init.from_checkpoint([ckpt_cfg]) - - def test_from_checkpoints_latest(self): - # only restore single latest - ckpt_cfg = utils.RestoreCheckpointConfig( - path=self.ckptdir.full_path, - mode="latest", - checkpointer_cls=MockCheckpointer, - ) - restored = list(self.train_state_init.from_checkpoints([ckpt_cfg])) - assert len(restored) == 1 - restored_state, restored_path = restored[0] - self.assertEqual(self.paths[-1], restored_state.path) - self.assertEqual(self.paths[-1], restored_path) - restored_state = self.train_state_init.from_checkpoint([ckpt_cfg]) - self.assertEqual(self.paths[-1], restored_state.path) - - def test_from_checkpoint_multiple_configs(self): - # uses first checkpoint with files present. - ckpt_cfg = utils.RestoreCheckpointConfig( - path=self.ckptdir.full_path, - mode="latest", - checkpointer_cls=MockCheckpointer, - ) - secondary_ckptdir = self.create_tempdir(name="secondary_checkpoints") - for s in (4, 5): - step_dir = secondary_ckptdir.mkdir(f"checkpoint_{s}") - step_dir.create_file("checkpoint") - secondary_ckpt_cfg = utils.RestoreCheckpointConfig( - path=secondary_ckptdir.full_path, - mode="latest", - checkpointer_cls=MockCheckpointer, - ) - restored = self.train_state_init.from_checkpoint( - [ckpt_cfg, secondary_ckpt_cfg] - ) - self.assertEqual(self.paths[-1], restored.path) - - def test_from_checkpoint_multiple_configs_one_empty(self): - # skips empty_checkpoints directory with no checkpoints present. - ckpt_cfg = utils.RestoreCheckpointConfig( - path=self.ckptdir.full_path, - mode="latest", - checkpointer_cls=MockCheckpointer, - ) - empty_ckptdir = self.create_tempdir(name="empty_checkpoints") - empty_ckpt_cfg = utils.RestoreCheckpointConfig( - path=empty_ckptdir.full_path, - mode="latest", - checkpointer_cls=MockCheckpointer, - ) - restored = self.train_state_init.from_checkpoint([empty_ckpt_cfg, ckpt_cfg]) - self.assertEqual(self.paths[-1], restored.path) - - def test_from_scratch(self): - self.assertTrue( - self.train_state_init.from_scratch(jax.random.PRNGKey(13)).from_scratch - ) - - def test_from_checkpoint_or_scratch(self): - ckpt_cfg = utils.RestoreCheckpointConfig( - path=self.ckptdir.full_path, - mode="latest", - checkpointer_cls=MockCheckpointer, - ) - empty_ckptdir = self.create_tempdir(name="empty_checkpoints") - empty_ckpt_cfg = utils.RestoreCheckpointConfig( - path=empty_ckptdir.full_path, - mode="latest", - checkpointer_cls=MockCheckpointer, - ) - - init_rng = jax.random.PRNGKey(13) - - # ckpt_cfg has checkpoints, restore from there - restored = self.train_state_init.from_checkpoint_or_scratch( - [empty_ckpt_cfg, ckpt_cfg], init_rng=init_rng - ) - self.assertEqual(self.paths[-1], restored.path) - self.assertFalse(restored.from_scratch) - - # no checkpoints available, init from scratch - initialized = self.train_state_init.from_checkpoint_or_scratch( - [empty_ckpt_cfg], init_rng=init_rng - ) - self.assertTrue(initialized.from_scratch) - - -class MutableGetInferFnTest(parameterized.TestCase): - - @parameterized.parameters((True,), (False,)) - def test_get_infer_fn_predict(self, use_flax_mutables): - batch_size = 2 - weight_const = 7 - mutable_const = 5 if use_flax_mutables else 0 - global_mesh = test_utils.create_global_mesh( - (1, 1, 1), ("model", "data", "chips") - ) - - def predict_batch(params, batch, flax_mutables): - result = ( - flax_mutables["mutable_w"] * batch["b"] if use_flax_mutables else 0 - ) - return result + (params["weight"] * batch["a"]) - - def get_data_layout(batch_size): - assert batch_size == 2 # total batch size, not per-shard - assert jax.process_index() == 0 - return partitioning.DataLayout( - batch_size=2, - shard_id=0, - num_shards=1, - is_first_host_in_replica_set=True, - ) - - def partition(fn, in_axis_resources, out_axis_resources): - fn = pjit( - fn, - in_shardings=in_axis_resources, - out_shardings=out_axis_resources, - ) - return partitioning.PjittedFnWithContext(fn, global_mesh) - - if use_flax_mutables: - flax_mutables_train_state = {"mutable_w": np.full((1,), mutable_const)} - flax_mutables_train_state_axes = {"mutable_w": PartitionSpec("model")} - else: - flax_mutables_train_state = None - flax_mutables_train_state_axes = None - - train_state = get_mock_train_state( - params={"weight": np.full((1,), weight_const)}, - flax_mutables=flax_mutables_train_state, - ) - train_state_axes = get_mock_train_state( - params={"weight": PartitionSpec("model")}, - flax_mutables=flax_mutables_train_state_axes, - ) - - partitioner = mock.Mock() - partitioner.data_partition_spec = PartitionSpec( - "data", - ) - partitioner.get_data_layout = get_data_layout - partitioner.partition = partition - partitioner.mesh = global_mesh - - def as_sharded_array(arr, axes, mesh=None): - with mesh: - return pjit(lambda x: x, in_shardings=None, out_shardings=axes)(arr) - - train_state.params = jax.tree_util.tree_map( - functools.partial(as_sharded_array, mesh=global_mesh), - train_state.params, - train_state_axes.params, - ) - - infer_fn = utils.get_infer_fn( - predict_batch, batch_size, train_state_axes, partitioner=partitioner - ) - - ds = tf.data.Dataset.from_tensor_slices({ - "a": np.transpose(np.tile(np.arange(8), (2, 1))), - "b": np.transpose(np.tile(np.arange(8), (2, 1))), - }).enumerate() - - all_indices, all_inferences = zip(*infer_fn(ds, train_state)) - - np.testing.assert_equal(all_indices, np.arange(8)) - np.testing.assert_equal( - all_inferences, - np.transpose( - np.tile(np.arange(8) * (weight_const + mutable_const), (2, 1)) - ), - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/t5x-main/t5x/version.py b/t5x-main/t5x/version.py deleted file mode 100644 index 9b7793020a1ef43045af25f47613eb7f43cc8e3b..0000000000000000000000000000000000000000 --- a/t5x-main/t5x/version.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2024 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Separate file for storing the current version of T5X. - -Stored in a separate file so that setup.py can reference the version without -pulling in all the dependencies in __init__.py. -""" - -__version__ = '0.0.0'